use crate::ndarray_ext::{NdArrayView, NdArrayViewMut};
use crate::smallvec::SmallVec;
use crate::tensor::{Tensor, TensorInternal};
use crate::{Float, NdArray};
use std::any::type_name;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
/// Length of the array type that backs `OutputArray`.
// NOTE(review): whether more than this many outputs can be stored depends on
// `crate::smallvec::SmallVec` (not visible here) — confirm before relying on it as a hard limit.
pub(crate) const NUM_MAX_OUTPUT: usize = 2;
/// Length of the array type that backs `InputArray`.
// NOTE(review): same caveat as `NUM_MAX_OUTPUT` regarding hard limit vs. inline capacity.
pub(crate) const NUM_MAX_INPUT: usize = 4;
/// Small-vector container for per-input values, backed by `[T; NUM_MAX_INPUT]`.
pub(crate) type InputArray<T> = SmallVec<[T; NUM_MAX_INPUT]>;
/// Small-vector container for per-output values, backed by `[T; NUM_MAX_OUTPUT]`.
pub(crate) type OutputArray<T> = SmallVec<[T; NUM_MAX_OUTPUT]>;
/// Outcome of producing one output: an array representation (`crate::ArrRepr`) on success,
/// or an `OpError` on failure.
pub(crate) type ComputeResult<'v, T> = Result<crate::ArrRepr<'v, T>, OpError>;
/// Error value an op can report (see `ComputeContext::set_error`).
///
/// Every variant carries a `String`; `NdArrayError` additionally wraps the
/// underlying `ndarray::ShapeError`.
#[derive(Clone, Debug, PartialEq)]
pub enum OpError {
/// A shape error coming from `ndarray`, together with a context prefix that is
/// displayed in front of it.
NdArrayError(String, ndarray::ShapeError),
/// Incompatible-shape failure; the payload is the message text.
IncompatibleShape(String),
/// Unsupported-type failure; the payload is the message text.
TypeUnsupported(String),
/// Invalid-dimensions failure; the payload is the message text.
InvalidDims(String),
/// Out-of-bounds failure; the payload is the message text.
OutOfBounds(String),
}
// No `source()` override: the default `Error` methods are used as-is.
impl std::error::Error for OpError {}
impl fmt::Display for OpError {
    /// Writes the human-readable form of the error.
    ///
    /// * `NdArrayError(prefix, e)` is rendered as `"<prefix>: <e>"`, using the
    ///   `Display` form of the wrapped shape error.
    /// * Every other variant is rendered as its message, verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            // Format the shape error through `{}` so the `Display` impl is selected
            // explicitly rather than through a bare `.fmt()` method call.
            OpError::NdArrayError(prefix, source) => write!(f, "{}: {}", prefix, source),
            // Message-only variants: nothing follows the message, so no ": "
            // separator is emitted after it.
            OpError::IncompatibleShape(msg)
            | OpError::TypeUnsupported(msg)
            | OpError::InvalidDims(msg)
            | OpError::OutOfBounds(msg) => f.write_str(msg),
        }
    }
}
/// Outputs collected from one `Op::compute` call, one entry per output:
/// `Some(Ok(_))` for a produced array, `Some(Err(_))` for a reported error and
/// `None` for an output that was explicitly left empty.
pub(crate) type Results<'v, T> = OutputArray<Option<ComputeResult<'v, T>>>;
/// Interface of an operation: a computation step plus a gradient definition.
///
/// `F` is the float element type the op works with.
pub trait Op<F: Float> {
/// Name of this op. Defaults to the Rust type name of the implementing type.
fn name(&self) -> &str {
type_name::<Self>()
}
/// Runs the op. Inputs are taken from `ctx` (`input` / `input_mut`) and results are
/// reported back through it (`append_output*`, `append_empty_output`, `set_error`).
fn compute(&self, ctx: &mut ComputeContext<F>);
/// Defines the gradient of the op. Implementations report one entry per input through
/// `GradientContext::append_input_grad`; a debug assertion in
/// `GradientContext::extract_input_grads` fires when nothing was appended.
fn grad(&self, ctx: &mut GradientContext<F>);
}
/// Placeholder op: both `compute` and `grad` are empty.
///
/// The struct stores no data; `F` is only recorded through `PhantomData`.
pub(crate) struct DummyOp<F: Float> {
    /// Marker tying the op to its float type `F`.
    pub phantom: PhantomData<F>,
}

impl<F: Float> DummyOp<F> {
    /// Builds the placeholder op.
    #[allow(dead_code)]
    pub(crate) fn new() -> Self {
        Self {
            phantom: PhantomData,
        }
    }
}

impl<F: Float> Op<F> for DummyOp<F> {
    /// Does nothing; in particular no output is appended to the context.
    fn compute(&self, _ctx: &mut ComputeContext<F>) {}

    /// Does nothing; in particular no input gradient is appended to the context.
    fn grad(&self, _ctx: &mut GradientContext<F>) {}
}
/// One input slot handed to an op during `compute`.
///
/// The view sits inside an `Option` so that `ComputeContext::input` /
/// `ComputeContext::input_mut` can move it out exactly once.
pub(crate) enum OpInput<'v, T: Float> {
    /// Read-only input.
    RO(Option<NdArrayView<'v, T>>),
    /// Read-write input.
    RW(Option<NdArrayViewMut<'v, T>>),
}

impl<'v, T: Float> OpInput<'v, T> {
    /// Wraps an immutable view as a not-yet-taken read-only input.
    #[inline]
    pub fn new(x: NdArrayView<'v, T>) -> Self {
        let slot = Some(x);
        Self::RO(slot)
    }

    /// Wraps a mutable view as a not-yet-taken read-write input.
    #[inline]
    pub fn new_mut(x: NdArrayViewMut<'v, T>) -> Self {
        let slot = Some(x);
        Self::RW(slot)
    }
}
/// Context passed to `Op::compute`: gives access to the op's inputs and collects its outputs.
pub struct ComputeContext<'t, 'v, T: Float> {
// Node being evaluated; read for the op name that appears in panic messages.
node: &'t TensorInternal<T>,
// Input slots; each one is moved out at most once by `input` / `input_mut`.
xs: InputArray<OpInput<'v, T>>,
// Outputs (or errors) pushed by the op, in the order they were appended.
ys: Results<'v, T>,
}
impl<'t, 'v, T: Float> ComputeContext<'t, 'v, T> {
    /// Consumes the context and returns everything the op appended, in order.
    pub(crate) fn extract_outputs(self) -> Results<'v, T> {
        self.ys
    }

    /// Creates a context for `node` with the given input slots and no outputs yet.
    #[inline]
    pub(crate) fn new(node: &'t TensorInternal<T>, xs: InputArray<OpInput<'v, T>>) -> Self {
        ComputeContext {
            node,
            xs,
            ys: OutputArray::new(),
        }
    }

    /// Moves the `i`-th input out of the context as an immutable view.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of range, if the `i`-th input was registered as
    /// read-write, or if the input has already been taken by an earlier call to
    /// `input(i)`.
    #[inline]
    pub fn input(&mut self, i: usize) -> NdArrayView<'v, T> {
        // `self.xs` and `self.node` are distinct fields, so the op name can be read
        // for the panic messages while the slot is still mutably borrowed.
        match self.xs.get_mut(i) {
            Some(OpInput::RO(slot)) => match slot.take() {
                Some(view) => view,
                None => panic!(
                    "Bad op impl of {}: input({})/input_mut({}) cannot be called twice",
                    self.node.get_op().name(),
                    i,
                    i
                ),
            },
            Some(OpInput::RW(_)) => panic!(
                "Bad op impl of {}: cannot perform immutable borrowing for input({})",
                self.node.get_op().name(),
                i
            ),
            None => panic!(
                "Bad op impl of {}: input({}) is out of range",
                self.node.get_op().name(),
                i
            ),
        }
    }

    /// Moves the `i`-th input out of the context as a mutable view.
    ///
    /// # Panics
    ///
    /// Panics if `i` is out of range, if the `i`-th input was registered as
    /// read-only, or if the input has already been taken by an earlier call to
    /// `input_mut(i)`.
    #[inline]
    pub fn input_mut(&mut self, i: usize) -> NdArrayViewMut<'v, T> {
        match self.xs.get_mut(i) {
            Some(OpInput::RW(slot)) => match slot.take() {
                Some(view) => view,
                None => panic!(
                    "Bad op impl of {}: input({})/input_mut({}) cannot be called twice",
                    self.node.get_op().name(),
                    i,
                    i
                ),
            },
            Some(OpInput::RO(_)) => panic!(
                "Bad op impl of {}: cannot perform mutable borrowing for input({})",
                self.node.get_op().name(),
                i
            ),
            None => panic!(
                "Bad op impl of {}: input_mut({}) is out of range",
                self.node.get_op().name(),
                i
            ),
        }
    }

    /// Appends a borrowed array view as the next output.
    #[inline]
    pub fn append_output_view(&mut self, y: NdArrayView<'v, T>) {
        self.ys.push(Some(Ok(crate::ArrRepr::View(y))));
    }

    /// Appends an owned array as the next output.
    #[inline]
    pub fn append_output(&mut self, y: NdArray<T>) {
        self.ys.push(Some(Ok(crate::ArrRepr::Owned(y))));
    }

    /// Appends an empty placeholder (`None`) as the next output.
    #[inline]
    pub fn append_empty_output(&mut self) {
        self.ys.push(None);
    }

    /// Appends `y` as a failed output. Like the `append_*` methods this pushes a new
    /// entry; it does not replace outputs that were appended earlier.
    #[inline]
    pub fn set_error(&mut self, y: OpError) {
        self.ys.push(Some(Err(y)));
    }

    /// Number of input slots in this context, whether or not they were already taken.
    #[inline]
    pub fn num_inputs(&self) -> usize {
        self.xs.len()
    }
}
/// Context passed to `Op::grad`: exposes the output tensor, its gradient and the graph,
/// and collects the gradients the op reports for its inputs.
pub struct GradientContext<'g, T: Float> {
// Tensor returned by `output_grad()`.
gy: Tensor<'g, T>,
// Tensor returned by `output()`; its `id` selects the node whose op's `grad` is run.
y: Tensor<'g, T>,
// Graph that `y` belongs to; used to look up input tensors and the node's op.
graph: &'g crate::graph::Graph<T>,
// Input gradients in the order they were passed to `append_input_grad`.
gxs: InputArray<Option<Tensor<'g, T>>>,
}
impl<'g, T: Float> GradientContext<'g, T> {
    /// Creates a context for output tensor `y` (with gradient tensor `gy`) living in `graph`.
    /// No input gradients are recorded yet.
    pub(crate) fn new(
        gy: Tensor<'g, T>,
        y: Tensor<'g, T>,
        graph: &'g crate::graph::Graph<T>,
    ) -> Self {
        GradientContext {
            gy,
            y,
            graph,
            gxs: InputArray::new(),
        }
    }

    /// Runs `Op::grad` of the op that owns `y` and returns the gradients it appended.
    ///
    /// The op is moved out of its graph node for the duration of the `grad` call and
    /// put back afterwards. If `grad` panics, the op is not put back.
    ///
    /// # Panics
    ///
    /// Panics if the node currently has no op. In debug builds it also panics when
    /// the op did not call `append_input_grad` at least once.
    pub(crate) fn extract_input_grads(mut self) -> InputArray<Option<Tensor<'g, T>>> {
        let id = self.y.id;
        // Copy the graph reference out so that `self` stays free to be borrowed
        // mutably by `grad` below.
        let graph = self.graph;

        // SAFETY: NOTE(review) — relies on the contract of `Graph::access_inner_mut`,
        // which is not visible here; presumably no other reference to this node is
        // alive at this point. The returned reference is dropped at the end of the
        // statement. Confirm against the graph implementation.
        let stolen = unsafe { mem::take(&mut graph.access_inner_mut(id).op) }
            .expect("bad impl: the op of the output node is missing");

        stolen.grad(&mut self);

        // SAFETY: same assumption as above; the reference does not outlive this block.
        unsafe {
            graph.access_inner_mut(id).op = Some(stolen);
        }

        debug_assert!(
            !self.gxs.is_empty(),
            "Bad Op impl: GradientContext::append_input_grad was not called"
        );
        self.gxs
    }

    /// Returns the `gy` tensor this context was created with.
    #[inline]
    pub fn output_grad(&self) -> Tensor<'g, T> {
        self.gy
    }

    /// Returns the `y` tensor this context was created with.
    #[inline]
    pub fn output(&self) -> Tensor<'g, T> {
        self.y
    }

    /// Returns the `i`-th input tensor of `y`.
    ///
    /// # Panics
    ///
    /// Panics with `"bad Op::grad impl"` when `y` has no `i`-th input.
    #[inline]
    pub fn input(&self, i: usize) -> Tensor<'g, T> {
        self.y
            .input_tensor(i, self.graph)
            .expect("bad Op::grad impl")
    }

    /// Number of incoming edges (inputs) of `y`.
    #[inline]
    pub fn num_inputs(&self) -> usize {
        // SAFETY: NOTE(review) — relies on the contract of `Tensor::inner`, which is
        // not visible here; only the length of `in_edges` is read. Confirm.
        unsafe { self.y.inner().in_edges.len() }
    }

    /// Returns the graph this context operates on.
    #[inline]
    pub fn graph(&self) -> &'g crate::graph::Graph<T> {
        self.graph
    }

    /// Records the gradient for the next input. `None` is stored as-is.
    #[inline]
    pub fn append_input_grad(&mut self, gx: Option<Tensor<'g, T>>) {
        self.gxs.push(gx);
    }
}