use crate::Float;
use crate::{op, Context};
use crate::{NdArray, NdArrayView};
use crate::error::OpError;
use crate::graph::{AsGraph, Graph, TensorID};
use crate::op::{GradientContext, SmallVec};
use crate::variable::VariableID;
use std::cell::Ref;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::marker::PhantomData;
use std::ops::{Add, Div, Mul, Sub};
/// A lightweight, copyable handle to one node in a computation graph.
///
/// Only the node id and a borrow of the owning graph are stored here; the
/// actual node data lives inside `Graph` and is reached via `inner()`.
#[derive(Clone, Copy)]
pub struct Tensor<'graph, F: Float> {
// `id`: index of the node inside `graph`; `graph`: the owning graph.
pub(crate) id: TensorID, pub(crate) graph: &'graph Graph<F>,
}
/// Debug output intentionally omits the op and input edges; it reports only
/// the node id plus the source/differentiability flags read from the graph.
impl<F: Float + std::fmt::Debug> std::fmt::Debug for Tensor<'_, F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Tensor")
.field("id", &self.id)
.field("is_source", &self.is_source())
.field("is_differentiable", &self.is_differentiable())
.finish()
}
}
impl<F: Float> PartialEq for Tensor<'_, F> {
    /// Two handles are equal only when they name the same node id in the
    /// very same graph instance (pointer identity, not structural equality).
    fn eq(&self, other: &Self) -> bool {
        let same_id = self.id == other.id;
        let same_graph = std::ptr::eq(self.graph, other.graph);
        same_id && same_graph
    }
}
impl<'graph, F: Float> Tensor<'graph, F> {
    /// Borrows this node's incoming (input) edges.
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn get_incoming_tensors(&self) -> Ref<SmallVec<IncomingTensor>> {
        Ref::map(self.inner(), |x| &x.incoming_nodes)
    }

    /// Returns the `i`-th input node as a `Tensor` bound to graph `g`,
    /// or `None` when `i` is out of bounds.
    pub(crate) fn get_incoming_tensor(
        &self,
        i: usize,
        g: &'graph Graph<F>,
    ) -> Option<Tensor<'graph, F>> {
        self.inner().incoming_nodes.get(i).map(|x| x.as_tensor(g))
    }

    /// Borrows the internal node data stored inside the graph.
    #[inline(always)]
    pub(crate) fn inner(&self) -> Ref<TensorInternal<F>> {
        self.graph.access_inner(self.id)
    }

    /// Returns the graph this tensor belongs to.
    #[inline]
    pub(crate) fn graph(&self) -> &'graph Graph<F> {
        self.graph
    }

    /// Evaluates this tensor within `ctx` and returns the resulting array.
    ///
    /// # Panics
    /// Panics when `ctx` was created for a different graph than this tensor.
    pub fn eval(&self, ctx: &Context<F>) -> Result<NdArray<F>, crate::EvalError> {
        crate::graph::assert_same_graph(ctx, self.graph);
        // Return the evaluator result directly; no need for a temporary.
        ctx.evaluator().eval(self)
    }

    /// Adds control dependencies: this tensor will be evaluated after `on`.
    #[inline]
    pub fn depends_on<A>(self, on: &[A]) -> Tensor<'graph, F>
    where
        A: AsRef<Tensor<'graph, F>> + Copy,
    {
        crate::tensor_ops::control_dependencies(self, on)
    }

    /// Starts building a new tensor node to be installed in `graph`.
    /// New nodes default to differentiable with no inputs.
    #[inline]
    pub fn builder(graph: &'graph impl AsGraph<F>) -> TensorBuilder<'graph, F> {
        TensorBuilder {
            graph: graph.as_graph(),
            shape: None,
            in_nodes: SmallVec::new(),
            differentiable: true,
            placeholder_name: None,
            backprop_inputs: None,
            knownshape: None,
            variable_id: None,
        }
    }

    /// Lazily applies the array transformation `f`, returning a new node.
    pub fn map(&self, f: fn(NdArrayView<F>) -> NdArray<F>) -> Tensor<'graph, F> {
        crate::tensor_ops::map(self, f)
    }

    /// Attaches a hook that observes this tensor's value during evaluation;
    /// the returned tensor forwards the value unchanged.
    #[inline]
    pub fn register_hook<H: crate::hooks::Hook<F> + 'static>(self, hook: H) -> Tensor<'graph, F> {
        Tensor::builder(self.graph)
            .append_input(self, false)
            .build(crate::tensor_ops::hook_ops::HookOp::new(hook))
    }

    /// Registers the `Show` hook (displays the evaluated value).
    #[inline]
    pub fn show(self) -> Tensor<'graph, F> {
        self.register_hook(crate::hooks::Show)
    }

    /// Registers the `ShowPrefixed` hook with a static label.
    #[inline]
    pub fn show_prefixed(self, prefix: &'static str) -> Tensor<'graph, F> {
        self.register_hook(crate::hooks::ShowPrefixed(prefix))
    }

    /// Registers the `ShowShape` hook (displays the runtime shape).
    #[inline]
    pub fn showshape(self) -> Tensor<'graph, F> {
        self.register_hook(crate::hooks::ShowShape)
    }

    /// Registers the `ShowPrefixedShape` hook with a static label.
    #[inline]
    pub fn show_prefixedshape(self, prefix: &'static str) -> Tensor<'graph, F> {
        self.register_hook(crate::hooks::ShowPrefixedShape(prefix))
    }

    /// Registers the `Print` hook emitting the given static message.
    #[inline]
    pub fn print(self, what: &'static str) -> Tensor<'graph, F> {
        self.register_hook(crate::hooks::Print(what))
    }

    /// Registers a raw closure hook that receives a view of the value.
    #[inline]
    pub fn raw_hook<FUN: Fn(&NdArrayView<F>) + 'static + Send + Sync>(
        self,
        f: FUN,
    ) -> Tensor<'graph, F> {
        self.register_hook(crate::hooks::Raw {
            raw: f,
            phantom: PhantomData,
        })
    }

    /// Returns this node's id within its graph.
    #[inline(always)]
    pub fn id(&self) -> usize {
        self.id
    }

    /// Number of incoming (input) edges.
    #[inline]
    pub fn num_inputs(&self) -> usize {
        self.inner().num_inputs()
    }

    /// Number of inputs used during backprop; falls back to the regular
    /// inputs when no explicit backprop inputs were registered.
    #[inline]
    pub fn num_backprop_inputs(&self) -> usize {
        let inner = self.inner();
        inner
            .backprop_inputs
            .as_ref()
            .unwrap_or(&inner.incoming_nodes)
            .len()
    }

    /// True when this node has no inputs (e.g. constants, placeholders).
    #[inline]
    pub fn is_source(&self) -> bool {
        self.inner().is_source()
    }

    /// The id of the variable backing this node, if it is a variable.
    #[inline]
    pub(crate) fn get_variable_id(&self) -> Option<VariableID> {
        self.inner().variable_id
    }

    /// Returns the `idx`-th backprop input as a tensor handle.
    ///
    /// # Panics
    /// Panics when `idx` is out of range (slice indexing).
    #[inline]
    pub fn get_backprop_input(&self, idx: usize) -> Tensor<'graph, F> {
        self.graph
            .tensor(self.inner().get_backprop_inputs()[idx].id)
    }

    /// True when this node is a named placeholder.
    #[inline]
    pub fn is_placeholder(&self) -> bool {
        self.inner().placeholder_name.is_some()
    }

    /// The placeholder name, if any. The name is `'static`, so copying it
    /// out of the `Ref` borrow is fine.
    #[inline]
    pub fn placeholder_name(&self) -> Option<&str> {
        self.inner().placeholder_name
    }

    /// Validates a concrete `shape` against this placeholder's known shape.
    ///
    /// # Panics
    /// Panics when the shapes are incompatible, or when this tensor has no
    /// known shape (i.e. it is not a placeholder with a declared shape).
    #[inline]
    pub fn validate_using_knownshape(&self, shape: &[usize]) {
        if let Some(ref knownshape) = self.inner().knownshape {
            if !knownshape.validate(shape) {
                panic!(
                    "Shape error: placeholder required {:?}, but got {:?}",
                    knownshape.get(),
                    shape
                );
            }
        } else {
            panic!("This is not a placeholder");
        }
    }

    /// True when gradients may flow through this node.
    #[inline]
    pub fn is_differentiable(&self) -> bool {
        self.inner().is_differentiable
    }

    /// True when this node is backed by a variable.
    #[inline]
    #[allow(unused)]
    pub(crate) fn is_variable(&self) -> bool {
        self.inner().is_variable()
    }

    /// Returns the statically known shape, with unknown dims (-1) clamped
    /// to 0; returns an empty vec when no shape information is recorded.
    pub fn shape(&self) -> Vec<usize> {
        if let Some(ref knownshape) = self.inner().knownshape {
            knownshape
                .get()
                .iter()
                .map(|&x| x.max(0) as usize)
                .collect()
        } else {
            vec![]
        }
    }

    /// NOTE(review): stub — always returns an empty Vec. Confirm whether
    /// real data extraction (via evaluation) is intended here.
    pub fn data(&self) -> Vec<F> {
        vec![]
    }

    /// Builds a constant tensor from raw data and a shape.
    ///
    /// On a shape/length mismatch this deliberately falls back to a
    /// zero-filled array instead of panicking (best-effort behavior).
    pub fn from_vec(
        _data: Vec<F>,
        shape: Vec<usize>,
        graph: &'graph Graph<F>,
    ) -> Tensor<'graph, F> {
        let array = match NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(&shape), _data) {
            Ok(arr) => arr,
            Err(_) => NdArray::zeros(scirs2_core::ndarray::IxDyn(&shape)),
        };
        crate::tensor_ops::convert_to_tensor(array, graph)
    }

    /// Alias of `is_differentiable` using common ML naming.
    pub fn requires_grad(&self) -> bool {
        self.is_differentiable()
    }

    /// Alias of `shape()`. NOTE(review): the name suggests element data,
    /// but it returns the shape — confirm this is the intended contract.
    pub fn to_vec(&self) -> Vec<usize> {
        self.shape()
    }

    /// Returns a view of this tensor with gradient flow cut off.
    #[inline]
    pub fn detach(&self) -> Tensor<'graph, F> {
        crate::tensor_ops::stop_gradient(*self)
    }

    /// Returns `self` when `requires_grad` is true, otherwise a detached
    /// copy. Note this cannot re-enable gradients on a detached tensor.
    #[inline]
    pub fn with_grad(&self, requires_grad: bool) -> Tensor<'graph, F> {
        if requires_grad {
            *self
        } else {
            self.detach()
        }
    }
}
/// Identity `AsRef`, letting APIs accept both `Tensor` and `&Tensor`.
impl<'b, T: Float> AsRef<Tensor<'b, T>> for Tensor<'b, T> {
#[inline(always)]
fn as_ref(&self) -> &Tensor<'b, T> {
self
}
}
/// Node data stored inside the graph; `Tensor` handles borrow into this.
pub(crate) struct TensorInternal<F: Float> {
// Node id within the owning graph.
pub(crate) id: usize,
// `Option` because the op may be temporarily taken ("stolen") during
// gradient computation — see the expect message in `get_op`.
pub(crate) op: Option<Box<dyn op::Op<F>>>,
// Input edges of this node.
pub(crate) incoming_nodes: SmallVec<IncomingTensor>,
// Topological rank: 0 for sources, 1 + max input rank otherwise.
pub(crate) topo_rank: usize,
// NOTE(review): holds the id of a shape-describing tensor (see
// `TensorBuilder::setshape`), not a dimension count — confirm.
pub(crate) shape: Option<usize>,
// Set only for placeholder nodes.
pub(crate) placeholder_name: Option<&'static str>,
pub(crate) is_differentiable: bool,
// Optional override of `incoming_nodes` used during backprop.
pub(crate) backprop_inputs: Option<SmallVec<IncomingTensor>>,
// Statically declared shape, if any (-1 marks unknown dims).
pub(crate) knownshape: Option<KnownShape>,
// Set when this node is backed by a variable.
pub(crate) variable_id: Option<VariableID>,
}
impl<F: Float> TensorInternal<F> {
    /// Creates a default internal node wired to the no-op `Dummy` op.
    #[allow(dead_code)]
    pub fn new() -> Self {
        TensorInternal {
            op: Some(Box::new(Dummy)),
            id: 0,
            topo_rank: 0,
            incoming_nodes: SmallVec::new(),
            shape: None,
            knownshape: None,
            placeholder_name: None,
            backprop_inputs: None,
            variable_id: None,
            is_differentiable: true,
        }
    }

    /// Returns the op of this node.
    ///
    /// # Panics
    /// Panics when the op has been temporarily taken by gradient code.
    pub fn get_op(&self) -> &dyn op::Op<F> {
        match self.op {
            Some(ref boxed) => boxed.as_ref(),
            None => panic!("bad impl: Op is now stolen in gradient.rs"),
        }
    }

    /// Node id within the owning graph.
    #[inline(always)]
    pub fn id(&self) -> usize {
        self.id
    }

    /// A source node has no inputs.
    #[inline]
    pub(crate) fn is_source(&self) -> bool {
        self.num_inputs() == 0
    }

    /// True when this node is backed by a variable.
    #[inline]
    pub(crate) fn is_variable(&self) -> bool {
        self.variable_id.is_some()
    }

    /// Number of incoming edges.
    #[inline]
    pub(crate) fn num_inputs(&self) -> usize {
        self.incoming_nodes.len()
    }

    /// Whether gradients may flow through this node.
    #[inline]
    #[allow(dead_code)]
    pub fn is_differentiable(&self) -> bool {
        self.is_differentiable
    }

    /// Inputs used for backprop: the explicit override when present,
    /// otherwise the regular incoming edges.
    #[inline]
    pub(crate) fn get_backprop_inputs(&self) -> &[IncomingTensor] {
        match self.backprop_inputs {
            Some(ref overridden) => overridden.as_slice(),
            None => self.incoming_nodes.as_slice(),
        }
    }
}
/// Diagnostic output: op name, id, input count and the raw input edges.
impl<T: Float> fmt::Debug for TensorInternal<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Node name: {}, id: {}, num of inputs: {}, in-edges: {:?}",
self.get_op().name(),
self.id(),
self.incoming_nodes.len(),
self.incoming_nodes
)
}
}
impl<T: Float> Eq for TensorInternal<T> {}
impl<T: Float> PartialEq for TensorInternal<T> {
    /// Node identity is determined solely by the id, matching `Hash`.
    #[inline(always)]
    fn eq(&self, other: &TensorInternal<T>) -> bool {
        self.id == other.id
    }
}
impl<T: Float> Hash for TensorInternal<T> {
    /// Hashes only the node id, consistent with `PartialEq`.
    #[inline(always)]
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.id().hash(state)
    }
}
/// Identity `AsRef` for APIs generic over borrow-like node data.
impl<T: Float> AsRef<TensorInternal<T>> for TensorInternal<T> {
#[inline(always)]
fn as_ref(&self) -> &TensorInternal<T> {
self
}
}
impl<T: Float> fmt::Display for TensorInternal<T> {
    /// User-facing form: just the op name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let op_name = self.get_op().name();
        write!(f, "name={}", op_name)
    }
}
/// An input edge of a node: which node feeds it, whether that input may be
/// mutated in place, and which of the producer's output arrays to read.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct IncomingTensor {
// Id of the producing node.
pub(crate) id: usize,
// Whether the consumer may mutate this input in place.
pub(crate) allow_mut: bool,
// Index selecting which output array of the producer to read.
pub(crate) array_selector: usize,
}
impl<'graph> IncomingTensor {
    /// Builds a read-only edge to `val`, reading its output array
    /// number `arrayselector`.
    #[inline]
    pub(crate) fn new<F: Float>(val: &Tensor<'graph, F>, arrayselector: usize) -> IncomingTensor {
        IncomingTensor {
            allow_mut: false,
            array_selector: arrayselector,
            id: val.id(),
        }
    }

    /// Builds an edge to `val` that the consumer may mutate in place.
    #[inline]
    pub(crate) fn new_mut<F: Float>(
        val: &Tensor<'graph, F>,
        array_selector: usize,
    ) -> IncomingTensor {
        IncomingTensor {
            allow_mut: true,
            array_selector,
            id: val.id(),
        }
    }

    /// Resolves this edge back into a `Tensor` handle on `graph`.
    #[inline(always)]
    pub(crate) fn as_tensor<F: Float>(&self, graph: &'graph Graph<F>) -> Tensor<'graph, F> {
        graph.tensor(self.id)
    }

    /// The variable id backing the producer node, if any.
    #[inline]
    #[allow(dead_code)]
    pub(crate) fn get_variable_id<F: Float>(&self, graph: &Graph<F>) -> Option<VariableID> {
        graph.access_inner(self.id).variable_id
    }
}
/// Builder that assembles and installs a new tensor node into a graph.
pub struct TensorBuilder<'g, F: Float> {
graph: &'g Graph<F>,
// `shape` stores the ID of a shape-describing tensor (see `setshape`);
// `in_nodes` collects the input edges in order.
shape: Option<usize>, in_nodes: SmallVec<IncomingTensor>,
differentiable: bool,
// Optional override of the inputs used during backprop.
backprop_inputs: Option<SmallVec<IncomingTensor>>,
// Statically declared shape (-1 marks unknown dims).
knownshape: Option<KnownShape>,
variable_id: Option<VariableID>,
placeholder_name: Option<&'static str>,
}
/// Max number of dims stored inline by `ShapeVec` before spilling to the heap.
const NUM_MAX_KNOWN_SHAPE_SIZE: usize = 4;
/// Small vector of signed dim sizes; `-1` denotes an unknown dimension.
type ShapeVec = smallvec::SmallVec<[isize; NUM_MAX_KNOWN_SHAPE_SIZE]>;
/// A statically declared shape where `-1` marks an unknown dimension.
pub(crate) struct KnownShape {
shape: ShapeVec,
// True when no dimension is -1 (all dims concrete).
#[allow(dead_code)]
is_fully_defined: bool,
}
impl KnownShape {
    /// Creates a `KnownShape` from `shape`, where `-1` marks an unknown
    /// (wildcard) dimension.
    ///
    /// # Panics
    /// Panics when any dimension is smaller than `-1`.
    pub(crate) fn new(shape: &[isize]) -> Self {
        // Reject anything below -1; 0 and positive sizes are valid dims.
        if shape.iter().any(|&a| a < -1) {
            panic!("Given shape ({:?}) contains invalid dim size(s)", &shape);
        }
        // Fully defined iff no wildcard dimension is present.
        let is_fully_defined = !shape.contains(&-1);
        Self {
            shape: ShapeVec::from(shape),
            is_fully_defined,
        }
    }

    /// The declared dims as a slice (may contain -1 wildcards).
    #[inline]
    pub fn get(&self) -> &[isize] {
        self.shape.as_slice()
    }

    /// Checks `target` against this shape: ranks must match, and every
    /// positive declared dim must equal the target dim (non-positive
    /// declared dims act as wildcards).
    pub fn validate(&self, target: &[usize]) -> bool {
        self.shape.len() == target.len()
            && self
                .shape
                .iter()
                .zip(target)
                .all(|(&i, &u)| i <= 0 || i as usize == u)
    }

    /// True when no dimension is a -1 wildcard.
    #[inline]
    #[allow(dead_code)]
    pub fn is_fully_defined(&self) -> bool {
        self.is_fully_defined
    }
}
/// Checks that installed nodes receive increasing topological ranks.
/// NOTE(review): `a`, `v` and `b` are all sources (equal rank), so the
/// id-order assertions on them hold only because `sort_by_key` is stable.
#[test]
#[allow(dead_code)]
fn test_topo_order() {
use crate::tensor_ops as T;
crate::run(|g| {
let a: Tensor<f32> = T::zeros(&[4, 2], g);
let v: Tensor<f32> = T::zeros(&[2, 3], g);
let b: Tensor<f32> = T::zeros(&[4, 3], g);
let z = T::matmul(a, v) + b;
let mut vars = [a.inner(), v.inner(), b.inner(), z.inner()];
// Stable sort keeps a, v, b (same rank) in insertion order; z ranks last.
vars.sort_by_key(|a| a.topo_rank);
assert_eq!(vars[0].id, a.id);
assert_eq!(vars[1].id, v.id);
assert_eq!(vars[2].id, b.id);
assert_eq!(vars[3].id, z.id);
});
}
impl<'graph, F: Float> TensorBuilder<'graph, F> {
    /// Marks the node as backed by the variable `s`.
    #[inline]
    pub(crate) fn set_variable(mut self, s: VariableID) -> TensorBuilder<'graph, F> {
        self.variable_id = Some(s);
        self
    }

    /// Declares a static shape; `-1` entries mark unknown dims.
    #[inline]
    pub(crate) fn set_knownshape(mut self, s: &[isize]) -> TensorBuilder<'graph, F> {
        self.knownshape = Some(KnownShape::new(s));
        self
    }

    /// Records the id of a shape-describing tensor for this node.
    #[inline]
    pub(crate) fn setshape(mut self, s: &Tensor<'graph, F>) -> TensorBuilder<'graph, F> {
        self.shape = Some(s.id());
        self
    }

    /// Sets whether gradients may flow through the node (default: true).
    #[inline]
    pub fn set_differentiable(mut self, differentiable: bool) -> TensorBuilder<'graph, F> {
        self.differentiable = differentiable;
        self
    }

    /// Appends an input edge reading output array 0 of `tensor`.
    #[inline]
    pub fn append_input<T: AsRef<Tensor<'graph, F>>>(
        self,
        tensor: T,
        allow_mut: bool,
    ) -> TensorBuilder<'graph, F> {
        self.append_input_with_selector(tensor, allow_mut, 0)
    }

    /// Appends an input edge reading output array `array_selector` of
    /// `tensor`, optionally allowing in-place mutation.
    ///
    /// # Panics
    /// Panics when `tensor` belongs to a different graph than the builder.
    #[inline]
    pub(crate) fn append_input_with_selector<T: AsRef<Tensor<'graph, F>>>(
        mut self,
        tensor: T,
        allow_mut: bool,
        array_selector: usize,
    ) -> TensorBuilder<'graph, F> {
        let t = tensor.as_ref();
        crate::graph::assert_same_graph(t.graph, self.graph);
        if allow_mut {
            self.in_nodes
                .push(IncomingTensor::new_mut(t, array_selector));
        } else {
            self.in_nodes.push(IncomingTensor::new(t, array_selector));
        }
        self
    }

    /// Marks the node as a placeholder with the given name.
    #[inline]
    pub(crate) fn set_placeholder_name(mut self, a: &'static str) -> TensorBuilder<'graph, F> {
        self.placeholder_name = Some(a);
        self
    }

    /// Appends a backprop-only input edge (overrides the regular inputs
    /// during gradient computation).
    ///
    /// # Panics
    /// Panics when `a` belongs to a different graph than the builder.
    #[inline]
    pub fn append_backprop_input<T: AsRef<Tensor<'graph, F>>>(
        mut self,
        a: T,
    ) -> TensorBuilder<'graph, F> {
        crate::graph::assert_same_graph(a.as_ref().graph, self.graph);
        self.backprop_inputs
            .get_or_insert_with(SmallVec::new)
            .push(IncomingTensor::new(a.as_ref(), 0));
        self
    }

    /// Finalizes the node with `op` and installs it into the graph,
    /// returning a handle to the freshly created tensor.
    pub fn build<O>(self, op: O) -> Tensor<'graph, F>
    where
        O: op::Op<F> + 'static,
    {
        let graph = self.graph;
        // Topological rank: 1 + the deepest input's rank; sources get 0.
        // (`max()` on an empty iterator is None, so no separate empty check
        // is needed.)
        let rank = self
            .in_nodes
            .iter()
            .map(|a| graph.access_inner(a.id).topo_rank + 1)
            .max()
            .unwrap_or(0);
        let new = TensorInternal {
            // The real id is assigned by `graph.install` below.
            id: usize::default(),
            op: Some(Box::new(op)),
            incoming_nodes: self.in_nodes,
            topo_rank: rank,
            shape: self.shape,
            is_differentiable: self.differentiable,
            backprop_inputs: self.backprop_inputs,
            knownshape: self.knownshape,
            variable_id: self.variable_id,
            placeholder_name: self.placeholder_name,
        };
        Tensor {
            id: graph.install(new),
            graph,
        }
    }
}
/// A no-op placeholder op used by `TensorInternal::new`.
#[allow(dead_code)]
pub(crate) struct Dummy;
/// `Dummy` computes nothing and propagates no gradient.
impl<T: Float> op::Op<T> for Dummy {
fn compute(&self, _: &mut op::ComputeContext<T>) -> Result<(), OpError> {
Ok(())
}
fn grad(&self, _: &mut GradientContext<T>) {}
}
use crate::tensor_ops as T;
// Implements `tensor OP scalar` (for `F` itself as the scalar) for both
// owned and borrowed tensors; the scalar is lifted into a graph constant.
// NOTE(review): `$op` is captured for symmetry with sibling macros but
// never expanded.
macro_rules! impl_bin_op_between_tensor_and_float_trait {
($trt:ident, $func:ident, $op:ident) => {
// Tensor OP F (tensor by value)
impl<'b, F: Float> $trt<F> for Tensor<'b, F> {
type Output = Tensor<'b, F>;
fn $func(self, rhs: F) -> Self::Output {
T::$func(&self, &T::scalar(rhs, self.graph))
}
}
// &Tensor OP F
impl<'l, 'b, F: Float> $trt<F> for &'l Tensor<'b, F> {
type Output = Tensor<'b, F>;
fn $func(self, rhs: F) -> Self::Output {
T::$func(self, &T::scalar(rhs, self.graph))
}
}
};
}
// Implements `scalar OP tensor` for a concrete primitive ($scalar_type),
// with the tensor taken either by value or by reference; the scalar is
// lifted into a graph constant via `T::scalar`.
// NOTE(review): `$op` is captured for symmetry with sibling macros but
// never expanded.
macro_rules! impl_bin_op_between_tensor_and_primitive {
    ($trt:ident, $func:ident, $op:ident, $scalar_type:ty) => {
        // scalar OP Tensor (tensor by value); the previously declared `'r`
        // lifetime was unused here and has been removed.
        impl<'b, F: Float> $trt<Tensor<'b, F>> for $scalar_type {
            type Output = Tensor<'b, F>;
            fn $func(self, rhs: Tensor<'b, F>) -> Self::Output {
                T::$func(
                    &T::scalar(
                        F::from(self).expect("Failed to convert to float"),
                        rhs.graph,
                    ),
                    &rhs,
                )
            }
        }
        // scalar OP &Tensor
        impl<'r, 'b, F: Float> $trt<&'r Tensor<'b, F>> for $scalar_type {
            type Output = Tensor<'b, F>;
            fn $func(self, rhs: &'r Tensor<'b, F>) -> Self::Output {
                T::$func(
                    &T::scalar(
                        F::from(self).expect("Failed to convert to float"),
                        rhs.graph,
                    ),
                    rhs,
                )
            }
        }
    };
}
// Tensor OP F (generic float scalar).
impl_bin_op_between_tensor_and_float_trait!(Add, add, AddOp);
impl_bin_op_between_tensor_and_float_trait!(Sub, sub, SubOp);
impl_bin_op_between_tensor_and_float_trait!(Mul, mul, MulOp);
impl_bin_op_between_tensor_and_float_trait!(Div, div, DivOp);
// f64 OP tensor.
impl_bin_op_between_tensor_and_primitive!(Add, add, AddOp, f64);
impl_bin_op_between_tensor_and_primitive!(Sub, sub, SubOp, f64);
impl_bin_op_between_tensor_and_primitive!(Mul, mul, MulOp, f64);
impl_bin_op_between_tensor_and_primitive!(Div, div, DivOp, f64);
// f32 OP tensor.
impl_bin_op_between_tensor_and_primitive!(Add, add, AddOp, f32);
impl_bin_op_between_tensor_and_primitive!(Sub, sub, SubOp, f32);
impl_bin_op_between_tensor_and_primitive!(Mul, mul, MulOp, f32);
impl_bin_op_between_tensor_and_primitive!(Div, div, DivOp, f32);
// Implements `tensor OP tensor` for all four value/reference combinations.
// NOTE(review): `$op` is captured for symmetry with sibling macros but
// never expanded.
macro_rules! impl_bin_op_between_tensors {
($trt:ident, $func:ident, $op:ident) => {
// Tensor OP Tensor
impl<'b, F: Float> $trt for Tensor<'b, F> {
type Output = Tensor<'b, F>;
fn $func(self, rhs: Tensor<'b, F>) -> Self::Output {
T::$func(&self, &rhs)
}
}
// Tensor OP &Tensor
impl<'r, 'b, F: Float> $trt<&'r Tensor<'b, F>> for Tensor<'b, F> {
type Output = Tensor<'b, F>;
fn $func(self, rhs: &'r Tensor<'b, F>) -> Self::Output {
T::$func(&self, rhs)
}
}
// &Tensor OP Tensor
impl<'l, 'b, F: Float> $trt<Tensor<'b, F>> for &'l Tensor<'b, F> {
type Output = Tensor<'b, F>;
fn $func(self, rhs: Tensor<'b, F>) -> Self::Output {
T::$func(self, &rhs)
}
}
// &Tensor OP &Tensor
impl<'l, 'r, 'b, F: Float> $trt<&'r Tensor<'b, F>> for &'l Tensor<'b, F> {
type Output = Tensor<'b, F>;
fn $func(self, rhs: &'r Tensor<'b, F>) -> Self::Output {
T::$func(self, rhs)
}
}
};
}
// Tensor OP tensor, for all value/reference combinations.
impl_bin_op_between_tensors!(Add, add, AddOp);
impl_bin_op_between_tensors!(Sub, sub, SubOp);
impl_bin_op_between_tensors!(Mul, mul, MulOp);
impl_bin_op_between_tensors!(Div, div, DivOp);
/// Conversion of values (tensors, index arrays, …) into graph tensors.
pub trait AsTensor<'graph, F: Float> {
fn as_tensor(&self, graph: &'graph impl AsGraph<F>) -> Tensor<'graph, F>;
}
impl<'graph, F: Float> AsTensor<'graph, F> for Tensor<'graph, F> {
    /// A tensor converts to itself. `_graph` is required by the trait
    /// signature but unused (the handle already carries its graph); the
    /// underscore silences the unused-variable warning.
    fn as_tensor(&self, _graph: &'graph impl AsGraph<F>) -> Tensor<'graph, F> {
        *self
    }
}
// Implements `AsTensor` for fixed-size integer arrays: the array becomes a
// 1-D constant tensor of length $num_elems in the target graph.
// NOTE(review): `F::from` panics when an element cannot be represented.
macro_rules! impl_as_tensor_for_array {
($num_elems:expr) => {
impl<'graph, F: Float, I: crate::Int> AsTensor<'graph, F> for [I; $num_elems] {
fn as_tensor(&self, graph: &'graph impl AsGraph<F>) -> Tensor<'graph, F> {
let vec = self
.iter()
.map(|&a| F::from(a).expect("Failed to convert to float"))
.collect::<Vec<F>>();
// Length always matches the shape, so `from_shape_vec` cannot fail here.
let arr = NdArray::from_shape_vec(scirs2_core::ndarray::IxDyn(&[self.len()]), vec)
.expect("Operation failed");
T::convert_to_tensor(arr, graph.as_graph())
}
}
};
}
// `AsTensor` for integer arrays of length 0 through 8.
impl_as_tensor_for_array!(0);
impl_as_tensor_for_array!(1);
impl_as_tensor_for_array!(2);
impl_as_tensor_for_array!(3);
impl_as_tensor_for_array!(4);
impl_as_tensor_for_array!(5);
impl_as_tensor_for_array!(6);
impl_as_tensor_for_array!(7);
impl_as_tensor_for_array!(8);