use smallvec::SmallVec;
use std::fmt::Debug;
use std::iter::repeat;
use std::mem::MaybeUninit;
use rten_base::num::{AsBool, Identities, IsInt};
use rten_shape_inference::ops as shape_ops;
use rten_tensor::prelude::*;
use rten_tensor::{AssumeInit, NdTensorView, NdTensorViewMut};
use rten_tensor::{Tensor, TensorView, TensorViewMut};
use crate::buffer_pool::{AutoReturn, BufferPool};
use crate::infer_shapes::{BinaryOp, InferShapes};
use crate::operator::{
IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
OutputTypesContext,
};
use crate::ops::{map_value, map_value_view};
use crate::value::{DataType, Value, ValueType, ValueView};
/// Compute the shape that results from broadcasting `a` and `b` together
/// using NumPy/ONNX broadcasting rules, or `None` if the shapes are
/// incompatible.
pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> Option<SmallVec<[usize; 4]>> {
    let ndim = a.len().max(b.len());
    let mut result: SmallVec<[usize; 4]> = SmallVec::with_capacity(ndim);

    // Walk dimensions from the innermost (last) outwards. A dimension that
    // is missing because one shape is shorter behaves as if it were 1.
    for i in 1..=ndim {
        let a_dim = if i <= a.len() { a[a.len() - i] } else { 1 };
        let b_dim = if i <= b.len() { b[b.len() - i] } else { 1 };
        let out_dim = if a_dim == b_dim || b_dim == 1 {
            a_dim
        } else if a_dim == 1 {
            b_dim
        } else {
            // Dims differ and neither is 1 => not broadcastable.
            return None;
        };
        result.push(out_dim);
    }

    // Dims were collected innermost-first; flip into conventional order.
    result.reverse();
    Some(result)
}
/// Return true if a binary op writing its output into `a` can run without
/// reallocating, i.e. `b` broadcasts to `a`'s shape.
fn can_run_binary_op_in_place<La: Layout, Lb: Layout>(a: &La, b: &Lb) -> bool {
    let target_shape = a.shape();
    b.can_broadcast_to(target_shape.as_ref())
}
/// Test whether broadcasting `from_shape` to `to_shape` can be done by a
/// simple "cycle and repeat" copy of the source data, and if so return the
/// `(cycles, repeats)` counts.
///
/// This is possible when broadcast dims appear only at the start (each output
/// pass cycles over the whole source) and/or at the end (each source element
/// is repeated consecutively). Returns `None` when a broadcast dim sits in
/// the middle of the shape.
///
/// Panics if `from_shape` has higher rank than `to_shape`.
pub fn fast_broadcast_cycles_repeats(
    from_shape: &[usize],
    to_shape: &[usize],
) -> Option<(usize, usize)> {
    // Identical shapes require no duplication at all.
    if from_shape == to_shape {
        return Some((1, 1));
    }

    // A single-element source is just repeated once per output element.
    if from_shape.iter().product::<usize>() == 1 {
        return Some((1, to_shape.iter().product()));
    }

    assert!(to_shape.len() >= from_shape.len());

    // Left-pad `from_shape` with 1s so both shapes have equal rank.
    let pad = to_shape.len() - from_shape.len();
    let dim_pair = |i: usize| {
        let from = if i < pad { 1 } else { from_shape[i - pad] };
        (from, to_shape[i])
    };
    let ndim = to_shape.len();

    // Scan broadcast dims from the front: their product is the cycle count.
    // A size-1 output dim contributes a factor of 1; a zero-size output dim
    // ends the scan (it cannot be produced by cycling a non-empty source).
    let mut start = 0;
    let mut cycles = 1;
    while start < ndim {
        let (from, to) = dim_pair(start);
        if from != 1 || to == 0 {
            break;
        }
        cycles *= to;
        start += 1;
    }

    // Scan broadcast dims from the back: their product is the repeat count.
    let mut end = ndim;
    let mut repeats = 1;
    while end > start {
        let (from, to) = dim_pair(end - 1);
        if from != 1 || to == 0 {
            break;
        }
        repeats *= to;
        end -= 1;
    }

    // All dims between the two scanned regions must match exactly.
    for i in start..end {
        let (from, to) = dim_pair(i);
        if from != to {
            return None;
        }
    }

    Some((cycles, repeats))
}
/// A kernel computing `R = op(T, U)` elementwise over broadcast tensor
/// operands.
pub trait BinaryKernel<T, U = T, R = T> {
    /// Apply the operation where `b` is broadcast against `a` by a simple
    /// cycle/repeat pattern (see [`fast_broadcast_cycles_repeats`]): `b`'s
    /// data is traversed `cycles` times, and each of its elements is used
    /// `repeats` times consecutively. Writes into `out` and returns the
    /// initialized slice.
    fn apply_fast<'dst>(
        &self,
        a: &[T],
        b: &[U],
        out: &'dst mut [MaybeUninit<R>],
        cycles: usize,
        repeats: usize,
    ) -> &'dst mut [R];

    /// Apply the operation elementwise over two equal-shaped 4D views,
    /// writing results to `out` in row-major order and returning the
    /// initialized slice.
    fn apply_indexed<'dst>(
        &self,
        a: NdTensorView<T, 4>,
        b: NdTensorView<U, 4>,
        out: &'dst mut [MaybeUninit<R>],
    ) -> &'dst mut [R];
}
// Blanket implementation: any `Fn(T, U) -> R` closure or function pointer is
// usable as a `BinaryKernel`.
impl<T: Copy, U: Copy, R, F: Fn(T, U) -> R> BinaryKernel<T, U, R> for F {
    fn apply_fast<'dst>(
        &self,
        a: &[T],
        b: &[U],
        out: &'dst mut [MaybeUninit<R>],
        cycles: usize,
        repeats: usize,
    ) -> &'dst mut [R] {
        // These asserts establish the invariant relied upon by the unchecked
        // indexing below: `i` takes exactly
        // `cycles * b.len() * repeats == a.len() == out.len()` values.
        assert_eq!(a.len(), out.len());
        assert_eq!(a.len(), b.len() * cycles * repeats);
        let mut i = 0;
        for _ in 0..cycles {
            if repeats == 1 {
                // Special-cased to avoid a degenerate inner loop per element.
                for b_elt in b {
                    // SAFETY: i < a.len() == out.len(), per the asserts above.
                    let (a_elt, out_elt) =
                        unsafe { (*a.get_unchecked(i), out.get_unchecked_mut(i)) };
                    out_elt.write(self(a_elt, *b_elt));
                    i += 1;
                }
            } else {
                for b_elt in b {
                    for _ in 0..repeats {
                        // SAFETY: i < a.len() == out.len(), per the asserts above.
                        let (a_elt, out_elt) =
                            unsafe { (*a.get_unchecked(i), out.get_unchecked_mut(i)) };
                        out_elt.write(self(a_elt, *b_elt));
                        i += 1;
                    }
                }
            }
        }
        // SAFETY: the loops above wrote every element of `out` exactly once.
        unsafe { out.assume_init() }
    }

    fn apply_indexed<'dst>(
        &self,
        a: NdTensorView<T, 4>,
        b: NdTensorView<U, 4>,
        out: &'dst mut [MaybeUninit<R>],
    ) -> &'dst mut [R] {
        assert_eq!(a.shape(), b.shape());
        assert_eq!(a.len(), out.len());
        let mut out_offset = 0;
        for i0 in 0..a.size(0) {
            for i1 in 0..a.size(1) {
                for i2 in 0..a.size(2) {
                    for i3 in 0..a.size(3) {
                        // SAFETY: each index is within the corresponding dim
                        // size, and out_offset < a.len() == out.len() since it
                        // increments once per element of `a`.
                        unsafe {
                            let a_elt = a.get_unchecked([i0, i1, i2, i3]);
                            let b_elt = b.get_unchecked([i0, i1, i2, i3]);
                            out.get_unchecked_mut(out_offset)
                                .write(self(*a_elt, *b_elt));
                            out_offset += 1;
                        }
                    }
                }
            }
        }
        // SAFETY: the loops above initialized all `a.len() == out.len()`
        // elements of `out`.
        unsafe { out.assume_init() }
    }
}
/// Apply an elementwise binary operation to two tensors, broadcasting them
/// to a common shape as needed.
///
/// Returns `OpError::IncompatibleInputShapes` if the shapes cannot be
/// broadcast together.
pub fn binary_op<T: Copy, U: Copy, R>(
    pool: &BufferPool,
    a: TensorView<T>,
    b: TensorView<U>,
    op: &dyn BinaryKernel<T, U, R>,
) -> Result<Tensor<R>, OpError> {
    let out_shape = broadcast_shapes(a.shape(), b.shape())
        .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))?;

    // Fast path: `a` already has the output shape, both inputs are
    // contiguous, and `b` broadcasts via a simple cycle/repeat pattern, so
    // the kernel can run over plain slices.
    if a.shape() == out_shape.as_slice()
        && let (Some(a_data), Some(b_data)) = (a.data(), b.data())
        && let Some((cycles, repeats)) = fast_broadcast_cycles_repeats(b.shape(), a.shape())
    {
        assert!(cycles * b_data.len() * repeats == a.len());
        let mut output = Tensor::uninit_in(pool, &out_shape);
        let out_init = op.apply_fast(a_data, b_data, output.data_mut().unwrap(), cycles, repeats);
        // The kernel guarantees it initialized the whole output.
        assert!(out_init.len() == output.len());
        // SAFETY: `apply_fast` returned an initialized slice covering all of
        // `output`'s elements.
        let output = unsafe { output.assume_init() };
        return Ok(output);
    }

    // General path: broadcast both inputs to the output shape and apply the
    // kernel over matching 4D inner views.
    let mut a = a.broadcast(out_shape.as_slice());
    let mut b = b.broadcast(out_shape.as_slice());
    let mut out_data = pool.alloc(a.len());
    let out_uninit = &mut out_data.spare_capacity_mut()[..a.len()];
    let mut out_offset = 0;
    // Pad with leading 1-sized axes until ndim >= 5.
    // NOTE(review): this pads 4D inputs to 5D as well — presumably
    // `inner_iter::<4>` requires at least one outer dim; confirm.
    while a.ndim() <= 4 {
        a.insert_axis(0);
        b.insert_axis(0);
    }
    a.inner_iter::<4>()
        .zip(b.inner_iter::<4>())
        .for_each(|(a, b)| {
            let len = a.len();
            let out_init = op.apply_indexed(a, b, &mut out_uninit[out_offset..out_offset + len]);
            out_offset += out_init.len();
        });
    // SAFETY: the kernel initialized `out_offset` elements of `out_data`'s
    // spare capacity.
    unsafe {
        out_data.set_len(out_offset);
    }
    Ok(Tensor::from_data(&out_shape, out_data))
}
/// A kernel which updates its first operand in place with `a = op(a, b)`.
pub trait BinaryMutKernel<T, U> {
    /// Apply the operation where `b` is broadcast against `a` by a
    /// cycle/repeat pattern (see [`fast_broadcast_cycles_repeats`]).
    fn apply_fast(&self, a: &mut [T], b: &[U], cycles: usize, repeats: usize);

    /// Apply the operation elementwise over two equal-shaped 4D views.
    fn apply_indexed(&self, a: NdTensorViewMut<T, 4>, b: NdTensorView<U, 4>);
}
// Blanket implementation: any `Fn(T, U) -> T` is usable as a mutating kernel.
impl<T: Copy, U: Copy, F: Fn(T, U) -> T> BinaryMutKernel<T, U> for F {
    fn apply_fast(&self, a: &mut [T], b: &[U], cycles: usize, repeats: usize) {
        // Establishes the bound relied on by the unchecked indexing below:
        // `i` takes exactly `cycles * b.len() * repeats == a.len()` values.
        assert!(cycles * b.len() * repeats == a.len());
        let mut i = 0;
        for _ in 0..cycles {
            if repeats == 1 {
                // Special-cased to avoid a degenerate inner loop per element.
                for b_elt in b {
                    // SAFETY: i < a.len(), per the assert above.
                    let a_elt = unsafe { a.get_unchecked_mut(i) };
                    *a_elt = self(*a_elt, *b_elt);
                    i += 1;
                }
            } else {
                for b_elt in b {
                    for _ in 0..repeats {
                        // SAFETY: i < a.len(), per the assert above.
                        let a_elt = unsafe { a.get_unchecked_mut(i) };
                        *a_elt = self(*a_elt, *b_elt);
                        i += 1;
                    }
                }
            }
        }
    }

    fn apply_indexed(&self, mut a: NdTensorViewMut<T, 4>, b: NdTensorView<U, 4>) {
        assert_eq!(a.shape(), b.shape());
        for i0 in 0..a.size(0) {
            for i1 in 0..a.size(1) {
                for i2 in 0..a.size(2) {
                    for i3 in 0..a.size(3) {
                        // SAFETY: each index is within the corresponding dim
                        // size, and the shapes of `a` and `b` match.
                        unsafe {
                            let a_elt = a.get_unchecked_mut([i0, i1, i2, i3]);
                            let b_elt = b.get_unchecked([i0, i1, i2, i3]);
                            *a_elt = self(*a_elt, *b_elt);
                        }
                    }
                }
            }
        }
    }
}
/// Apply an elementwise binary operation to `a` in place, broadcasting `b`
/// to `a`'s shape. The caller must ensure `b` is broadcastable to `a`.
fn binary_op_in_place<T: Copy + Debug, U: Copy + Debug>(
    mut a: TensorViewMut<T>,
    b: TensorView<U>,
    op: &dyn BinaryMutKernel<T, U>,
) {
    // Fast path: `b` is contiguous, broadcasts via a simple cycle/repeat
    // pattern, and `a` is contiguous, so the kernel can walk plain slices.
    if let Some(b_data) = b.data() {
        if let Some((cycles, repeats)) = fast_broadcast_cycles_repeats(b.shape(), a.shape()) {
            if let Some(a_data) = a.data_mut() {
                op.apply_fast(a_data, b_data, cycles, repeats);
                return;
            }
        }
    }

    // General path: broadcast `b` to `a`'s shape and apply the kernel over
    // matching 4D inner views, padding with leading axes as needed.
    let mut b = b.broadcast(a.shape());
    while a.ndim() <= 4 {
        a.insert_axis(0);
        b.insert_axis(0);
    }
    for (a_inner, b_inner) in a.inner_iter_mut::<4>().zip(b.inner_iter::<4>()) {
        op.apply_indexed(a_inner, b_inner);
    }
}
/// Apply a commutative binary operation, ordering the operands so that the
/// larger tensor comes first. This maximizes the chance of `binary_op`
/// taking its fast path, where the first operand has the output shape.
fn binary_commutative_op<T: Copy + Debug + Default>(
    pool: &BufferPool,
    a: TensorView<T>,
    b: TensorView<T>,
    op: &dyn BinaryKernel<T>,
) -> Result<Tensor<T>, OpError> {
    let (first, second) = if b.len() > a.len() { (b, a) } else { (a, b) };
    binary_op(pool, first, second, op)
}
/// Run a binary operator function with input-type dispatch.
///
/// Requires input 0 to be a float or i32 tensor, extracts input 1 as the
/// same element type, and invokes `$op_func(pool, a, b)`.
macro_rules! run_typed_op {
    ($pool:expr, $inputs:expr, $op_func:ident) => {{
        let a = $inputs.require(0)?;
        map_value_view!(a, a, [FloatTensor, Int32Tensor], {
            let b = $inputs.require_as(1)?;
            $op_func($pool, a, b).into_op_result()
        })
    }};
    // Variant which allocates from a fresh buffer pool.
    ($inputs:expr, $op_func:ident) => {
        run_typed_op!(&BufferPool::new(), $inputs, $op_func)
    };
}
/// Run a binary operator in place when possible, with input-type dispatch.
///
/// Updates `$input` in place via `$in_place_op_func` when `$other`'s first
/// input broadcasts to its shape; otherwise falls back to the allocating
/// `$op_func`, returning `$input`'s buffer to the pool.
macro_rules! run_typed_op_in_place {
    ($pool:expr, $input:expr, $other: expr, $in_place_op_func:ident, $op_func:ident) => {{
        map_value!($input, a, [FloatTensor, Int32Tensor], {
            let b = $other.require_as(0)?;
            if can_run_binary_op_in_place(&a, &b) {
                $in_place_op_func(a.view_mut(), b);
                Ok(a.into())
            } else {
                let a = a.auto_return($pool);
                $op_func($pool, a.view(), b.view()).map(|t| t.into())
            }
        })
    }};
}
/// Compute the elementwise sum of two tensors, broadcasting as needed.
pub fn add<T: Copy + Debug + Default + std::ops::Add<Output = T>>(
    pool: &BufferPool,
    a: TensorView<T>,
    b: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    // Addition is commutative, so operands may be reordered for speed.
    binary_commutative_op(pool, a, b, &|lhs, rhs| lhs + rhs)
}
/// Add `b` to `a` elementwise, writing the result into `a`. `b` must be
/// broadcastable to `a`'s shape.
pub fn add_in_place<T: Copy + Debug + std::ops::Add<Output = T>>(
    a: TensorViewMut<T>,
    b: TensorView<T>,
) {
    binary_op_in_place(a, b, &|lhs, rhs| lhs + rhs);
}
/// Operator which computes the elementwise sum of two tensors.
#[derive(Debug)]
pub struct Add {}

impl Operator for Add {
    fn name(&self) -> &str {
        "Add"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        run_typed_op!(ctx.pool(), ctx.inputs(), add)
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    fn is_commutative(&self) -> bool {
        true
    }

    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        run_typed_op_in_place!(ctx.pool(), input, ctx.inputs(), add_in_place, add)
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&shape_ops::Add)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // Output element type matches the first input's.
        Some([OutputType::CopyFromInput(0)].into())
    }
}
/// Define a logical boolean operator (function + `Operator` impl).
///
/// The generated function converts each element pair to `bool` via `AsBool`,
/// applies `$expr` and returns an i32 tensor of 0s and 1s. The generated
/// `Operator::run` accepts i32 tensor inputs.
macro_rules! logical_boolean_op {
    ($op:ident, $op_fn:ident, $expr:expr) => {
        pub fn $op_fn<T: AsBool + Copy + Debug>(
            pool: &BufferPool,
            a: TensorView<T>,
            b: TensorView<T>,
        ) -> Result<Tensor<i32>, OpError> {
            #[allow(clippy::redundant_closure_call)]
            binary_op(pool, a, b, &|x: T, y: T| {
                $expr(x.as_bool(), y.as_bool()).into()
            })
        }

        #[derive(Debug)]
        pub struct $op {}

        impl Operator for $op {
            fn name(&self) -> &str {
                stringify!($op)
            }

            fn max_inputs(&self) -> Option<usize> {
                Some(2)
            }

            // And/Or/Xor are all symmetric in their operands.
            fn is_commutative(&self) -> bool {
                true
            }

            fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
                let inputs = ctx.inputs();
                let a: TensorView<i32> = inputs.require_as(0)?;
                let b: TensorView<i32> = inputs.require_as(1)?;
                $op_fn(ctx.pool(), a, b).into_op_result()
            }

            fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
                Some([OutputType::Fixed(ValueType::Tensor(DataType::Int32))].into())
            }

            fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
                Some(&BinaryOp)
            }
        }
    };
}
// Logical boolean operators. Inputs are interpreted as booleans (non-zero ==
// true) and results are i32 tensors of 0s and 1s.
logical_boolean_op!(And, and, |x, y| x && y);
logical_boolean_op!(Or, or, |x, y| x || y);
logical_boolean_op!(Xor, xor, |x, y| x ^ y);
pub fn div<
T: Copy
+ Debug
+ Default
+ std::ops::Mul<Output = T>
+ std::ops::Div<Output = T>
+ IsInt
+ Identities,
>(
pool: &BufferPool,
a: TensorView<T>,
b: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
match (T::is_int(), b.item()) {
(false, Some(scalar)) => mul(pool, a, Tensor::from_scalar(T::one() / *scalar).view()),
_ => binary_op(pool, a, b, &|x, y| x / y),
}
}
pub fn div_in_place<
T: Copy + Debug + std::ops::Mul<Output = T> + std::ops::Div<Output = T> + IsInt + Identities,
>(
a: TensorViewMut<T>,
b: TensorView<T>,
) {
match (T::is_int(), b.item()) {
(false, Some(scalar)) => mul_in_place(a, Tensor::from_scalar(T::one() / *scalar).view()),
_ => binary_op_in_place(a, b, &|x, y| x / y),
}
}
/// Operator which computes the elementwise quotient of two tensors.
#[derive(Debug)]
pub struct Div {}

impl Operator for Div {
    fn name(&self) -> &str {
        "Div"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        run_typed_op!(ctx.pool(), ctx.inputs(), div)
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        run_typed_op_in_place!(ctx.pool(), input, ctx.inputs(), div_in_place, div)
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&shape_ops::Div)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
}
/// Comparison applied elementwise by [`boolean_op`].
enum BooleanOp {
    Equal,
    Less,
    LessOrEqual,
    Greater,
    GreaterOrEqual,
}
/// Compare two tensors elementwise using `op`, producing an i32 tensor of
/// 0s (false) and 1s (true).
fn boolean_op<T: Copy + Debug + PartialEq + PartialOrd>(
    pool: &BufferPool,
    a: TensorView<T>,
    b: TensorView<T>,
    op: BooleanOp,
) -> Result<Tensor<i32>, OpError> {
    let kernel = |x: T, y: T| -> i32 {
        let outcome = match op {
            BooleanOp::Equal => x == y,
            BooleanOp::Less => x < y,
            BooleanOp::LessOrEqual => x <= y,
            BooleanOp::Greater => x > y,
            BooleanOp::GreaterOrEqual => x >= y,
        };
        i32::from(outcome)
    };
    binary_op(pool, a, b, &kernel)
}
/// Define a comparison operator (function + `Operator` impl) which wraps
/// [`boolean_op`] with a fixed [`BooleanOp`] variant.
macro_rules! boolean_cmp_op {
    ($name:ident, $func:ident, $infer_shapes:expr) => {
        pub fn $func<T: Copy + Debug + PartialEq + PartialOrd>(
            pool: &BufferPool,
            a: TensorView<T>,
            b: TensorView<T>,
        ) -> Result<Tensor<i32>, OpError> {
            boolean_op(pool, a, b, BooleanOp::$name)
        }

        #[derive(Debug)]
        pub struct $name {}

        impl Operator for $name {
            fn name(&self) -> &str {
                stringify!($name)
            }

            fn max_inputs(&self) -> Option<usize> {
                Some(2)
            }

            // Of the comparison operators, only equality is symmetric.
            fn is_commutative(&self) -> bool {
                stringify!($name) == "Equal"
            }

            fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
                run_typed_op!(ctx.pool(), ctx.inputs(), $func)
            }

            fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
                Some([OutputType::Fixed(ValueType::Tensor(DataType::Int32))].into())
            }

            fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
                $infer_shapes(self)
            }
        }
    };
}
// Comparison operators. Each produces an i32 tensor of 0s and 1s.
boolean_cmp_op!(Equal, equal, |_| Some(
    &shape_ops::Equal as &dyn InferShapes
));
boolean_cmp_op!(Greater, greater, |_| Some(&BinaryOp as &dyn InferShapes));
boolean_cmp_op!(GreaterOrEqual, greater_or_equal, |_| Some(
    &BinaryOp as &dyn InferShapes
));
boolean_cmp_op!(Less, less, |_| Some(&BinaryOp as &dyn InferShapes));
boolean_cmp_op!(LessOrEqual, less_or_equal, |_| Some(
    &BinaryOp as &dyn InferShapes
));
/// Compute `x mod y` with the result taking the sign of the divisor `y`,
/// matching Python's `%` / floored division, rather than Rust's `%` which
/// takes the sign of the dividend.
fn rem_floor<
    T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Rem<Output = T>,
>(
    x: T,
    y: T,
) -> T {
    let zero = T::default();
    let rem = x % y;
    // If the truncated remainder is nonzero and its sign disagrees with the
    // divisor's, shift it by one divisor to land in the divisor's half-open
    // range.
    let sign_mismatch = (rem > zero) != (y > zero);
    if rem != zero && sign_mismatch {
        rem + y
    } else {
        rem
    }
}
/// Division convention used by [`mod_op`] to define the remainder's sign.
pub enum DivMode {
    // Floored division: remainder takes the sign of the divisor (Python `%`).
    FloorDiv,
    // Truncated division: remainder takes the sign of the dividend (C `fmod`).
    TruncDiv,
}
/// Compute the elementwise remainder of `a / b`, broadcasting as needed.
///
/// `mode` selects the sign convention: [`DivMode::FloorDiv`] gives the
/// remainder the divisor's sign, [`DivMode::TruncDiv`] the dividend's.
pub fn mod_op<
    T: Copy + Debug + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Rem<Output = T>,
>(
    pool: &BufferPool,
    a: TensorView<T>,
    b: TensorView<T>,
    mode: DivMode,
) -> Result<Tensor<T>, OpError> {
    match mode {
        DivMode::FloorDiv => binary_op(pool, a, b, &|x, y| rem_floor(x, y)),
        DivMode::TruncDiv => binary_op(pool, a, b, &|x, y| x % y),
    }
}
/// Operator which computes the elementwise remainder of two tensors.
#[derive(Debug)]
pub struct Mod {
    // When true, use truncated (fmod/C-style) semantics; otherwise floored
    // (Python-style) semantics. Mirrors the ONNX `Mod` operator's `fmod`
    // attribute.
    pub fmod: bool,
}

impl Operator for Mod {
    fn name(&self) -> &str {
        "Mod"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let a = inputs.require(0)?;
        let mode = if self.fmod {
            DivMode::TruncDiv
        } else {
            DivMode::FloorDiv
        };
        map_value_view!(a, a, [FloatTensor, Int32Tensor], {
            let b = inputs.require_as(1)?;
            mod_op(ctx.pool(), a, b, mode).into_op_result()
        })
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&BinaryOp)
    }
}
/// Compute the elementwise product of two tensors, broadcasting as needed.
pub fn mul<T: Copy + Debug + Default + std::ops::Mul<Output = T>>(
    pool: &BufferPool,
    a: TensorView<T>,
    b: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    // Multiplication is commutative, so operands may be reordered for speed.
    binary_commutative_op(pool, a, b, &|lhs, rhs| lhs * rhs)
}
/// Multiply `a` by `b` elementwise, writing the result into `a`. `b` must
/// be broadcastable to `a`'s shape.
pub fn mul_in_place<T: Copy + Debug + std::ops::Mul<Output = T>>(
    a: TensorViewMut<T>,
    b: TensorView<T>,
) {
    binary_op_in_place(a, b, &|lhs, rhs| lhs * rhs);
}
/// Operator which computes the elementwise product of two tensors.
#[derive(Debug)]
pub struct Mul {}

impl Operator for Mul {
    fn name(&self) -> &str {
        "Mul"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        run_typed_op!(ctx.pool(), ctx.inputs(), mul)
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    fn is_commutative(&self) -> bool {
        true
    }

    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        run_typed_op_in_place!(ctx.pool(), input, ctx.inputs(), mul_in_place, mul)
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&shape_ops::Mul)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
}
/// Trait for raising a value to a power, with fast paths for common small
/// exponents.
pub trait FastPow<Exp>: Copy {
    fn fast_pow(self, exponent: Exp) -> Self;
}
impl FastPow<f32> for f32 {
    /// Raise to a float power, special-casing exponents 2 and 3 which can
    /// be computed with plain multiplications instead of a `powf` call.
    fn fast_pow(self, exponent: f32) -> Self {
        match exponent {
            e if e == 2. => self * self,
            e if e == 3. => self * self * self,
            e => self.powf(e),
        }
    }
}
impl FastPow<f32> for i32 {
    // Route through the f32 implementation and truncate back to i32.
    // NOTE(review): bases or results outside f32's exactly-representable
    // integer range will lose precision — confirm this is acceptable for
    // callers.
    fn fast_pow(self, exponent: f32) -> Self {
        (self as f32).fast_pow(exponent) as i32
    }
}
impl FastPow<i32> for i32 {
    // Integer powers use wrapping arithmetic on overflow; negative exponents
    // fall back to the float implementation (which truncates the result).
    fn fast_pow(self, exponent: i32) -> Self {
        if exponent == 2 {
            self.wrapping_mul(self)
        } else if exponent == 3 {
            self.wrapping_mul(self).wrapping_mul(self)
        } else if exponent >= 0 {
            self.wrapping_pow(exponent as u32)
        } else {
            self.fast_pow(exponent as f32)
        }
    }
}
/// Raise elements of `base` to powers from `exp`, broadcasting as needed.
pub fn pow<E: Copy, T: FastPow<E>>(
    pool: &BufferPool,
    base: TensorView<T>,
    exp: TensorView<E>,
) -> Result<Tensor<T>, OpError> {
    match exp.item() {
        // Single scalar exponent: a plain elementwise map suffices.
        Some(&e) => Ok(base.map_in(pool, |x| x.fast_pow(e))),
        None => binary_op(pool, base, exp, &|b: T, e: E| b.fast_pow(e)),
    }
}
/// Raise elements of `base` to powers from `exp` in place. `exp` must be
/// broadcastable to `base`'s shape.
pub fn pow_in_place<E: Copy + Debug, T: Copy + Debug + FastPow<E>>(
    mut base: TensorViewMut<T>,
    exp: TensorView<E>,
) {
    match exp.item() {
        // Single scalar exponent: a plain elementwise update suffices.
        Some(&e) => base.apply(|x| x.fast_pow(e)),
        None => binary_op_in_place(base, exp, &|b: T, e: E| b.fast_pow(e)),
    }
}
/// Operator which raises elements of a base tensor to powers from an
/// exponent tensor.
#[derive(Debug)]
pub struct Pow {}

impl Pow {
    /// Dispatch [`pow`] over the supported (base, exponent) dtype pairs:
    /// (f32, f32), (i32, f32) and (i32, i32).
    fn eval(
        &self,
        ctx: &OpRunContext,
        base: ValueView,
        exponent: ValueView,
    ) -> Result<Value, OpError> {
        match (base, exponent) {
            (ValueView::FloatTensor(base), ValueView::FloatTensor(exponent)) => {
                pow(ctx.pool(), base, exponent).map(|t| t.into())
            }
            (ValueView::Int32Tensor(base), ValueView::FloatTensor(exponent)) => {
                pow(ctx.pool(), base, exponent).map(|t| t.into())
            }
            (ValueView::Int32Tensor(base), ValueView::Int32Tensor(exponent)) => {
                pow(ctx.pool(), base, exponent).map(|t| t.into())
            }
            _ => Err(OpError::UnsupportedValue(
                "Unsupported base and exponent type combination",
            )),
        }
    }
}
impl Operator for Pow {
    fn name(&self) -> &str {
        "Pow"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let base = inputs.require(0)?;
        let exponent = inputs.require(1)?;
        self.eval(ctx, base, exponent).into_op_result()
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    fn run_in_place(&self, base: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        // The in-place operand is passed separately, so here the exponent is
        // the first remaining input (index 0), unlike `run` where it is
        // input 1.
        let exponent = ctx.inputs().require(0)?;
        if can_run_binary_op_in_place(&base, &exponent) {
            match (base, exponent) {
                (Value::FloatTensor(mut base), ValueView::FloatTensor(exponent)) => {
                    pow_in_place(base.view_mut(), exponent);
                    Ok(base.into())
                }
                (Value::Int32Tensor(mut base), ValueView::FloatTensor(exponent)) => {
                    pow_in_place(base.view_mut(), exponent);
                    Ok(base.into())
                }
                (Value::Int32Tensor(mut base), ValueView::Int32Tensor(exponent)) => {
                    pow_in_place(base.view_mut(), exponent);
                    Ok(base.into())
                }
                _ => Err(OpError::UnsupportedValue(
                    "Unsupported base and exponent type combination",
                )),
            }
        } else {
            // Exponent does not broadcast to the base's shape; fall back to
            // the allocating path.
            self.eval(ctx, base.as_view(), exponent)
        }
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&BinaryOp)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
}
/// Compute the elementwise difference `a - b`, broadcasting as needed.
pub fn sub<T: Copy + Debug + Default + std::ops::Sub<Output = T>>(
    pool: &BufferPool,
    a: TensorView<T>,
    b: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    // Subtraction is not commutative, so the operand order is preserved.
    binary_op(pool, a, b, &|lhs, rhs| lhs - rhs)
}
/// Subtract `b` from `a` elementwise, writing the result into `a`. `b` must
/// be broadcastable to `a`'s shape.
pub fn sub_in_place<T: Copy + Debug + std::ops::Sub<Output = T>>(
    a: TensorViewMut<T>,
    b: TensorView<T>,
) {
    binary_op_in_place(a, b, &|lhs, rhs| lhs - rhs);
}
/// Operator which computes the elementwise difference of two tensors.
#[derive(Debug)]
pub struct Sub {}

impl Operator for Sub {
    fn name(&self) -> &str {
        "Sub"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        run_typed_op!(ctx.pool(), ctx.inputs(), sub)
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        run_typed_op_in_place!(ctx.pool(), input, ctx.inputs(), sub_in_place, sub)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&shape_ops::Sub)
    }
}
/// Select elements from `x` or `y` according to `cond`, broadcasting all
/// three inputs to a common shape.
///
/// Each output element is taken from `x` where `cond` is non-zero and from
/// `y` otherwise. Errors if the shapes cannot be broadcast together.
pub fn where_op<T: Copy>(
    pool: &BufferPool,
    cond: TensorView<i32>,
    x: TensorView<T>,
    y: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    // The result shape is the broadcast of all three input shapes.
    let broadcast_xy_shape = broadcast_shapes(x.shape(), y.shape())
        .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))?;
    let result_shape = broadcast_shapes(cond.shape(), &broadcast_xy_shape)
        .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))?;
    let out_len = result_shape.iter().product();
    let mut out_data = pool.alloc(out_len);
    let mut cond = cond.broadcast(result_shape.as_slice());
    let mut x = x.broadcast(result_shape.as_slice());
    let mut y = y.broadcast(result_shape.as_slice());
    // Pad with leading 1-sized axes so the views can be iterated as 4D
    // inner views, matching `binary_op`.
    while cond.ndim() <= 4 {
        cond.insert_axis(0);
        x.insert_axis(0);
        y.insert_axis(0);
    }
    let out_uninit = &mut out_data.spare_capacity_mut()[..cond.len()];
    let mut out_offset = 0;
    cond.inner_iter::<4>()
        .zip(x.inner_iter::<4>().zip(y.inner_iter::<4>()))
        .for_each(|(cond, (x, y))| {
            for i0 in 0..cond.size(0) {
                for i1 in 0..cond.size(1) {
                    for i2 in 0..cond.size(2) {
                        for i3 in 0..cond.size(3) {
                            // SAFETY: all views share the broadcast shape, so
                            // the indices are in-bounds for each, and
                            // out_offset < cond.len() == out_uninit.len()
                            // since it increments once per element.
                            unsafe {
                                let cond_elt = *cond.get_unchecked([i0, i1, i2, i3]);
                                let x_elt = *x.get_unchecked([i0, i1, i2, i3]);
                                let y_elt = *y.get_unchecked([i0, i1, i2, i3]);
                                let out_elt = if cond_elt != 0 { x_elt } else { y_elt };
                                out_uninit.get_unchecked_mut(out_offset).write(out_elt);
                                out_offset += 1;
                            }
                        }
                    }
                }
            }
        });
    assert!(out_offset == cond.len());
    // SAFETY: exactly `cond.len()` elements of the spare capacity were
    // initialized above.
    unsafe {
        out_data.set_len(cond.len());
    }
    Ok(Tensor::from_data(&result_shape, out_data))
}
/// Operator which selects elements from one of two tensors based on a
/// condition tensor.
#[derive(Debug)]
pub struct Where {}

impl Operator for Where {
    fn name(&self) -> &str {
        "Where"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(3)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let condition = inputs.require_as(0)?;
        let x = inputs.require(1)?;
        map_value_view!(x, x, [FloatTensor, Int32Tensor], {
            let y = inputs.require_as(2)?;
            where_op(ctx.pool(), condition, x, y).into_op_result()
        })
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // Output element type matches the `x` (true-branch) input.
        Some([OutputType::CopyFromInput(1)].into())
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&shape_ops::Where)
    }
}
#[cfg(test)]
mod tests {
use std::error::Error;
use rten_tensor::prelude::*;
use rten_tensor::test_util::expect_equal;
use rten_tensor::{Scalar, Tensor};
use rten_testing::TestCases;
use super::fast_broadcast_cycles_repeats;
use super::{
Add, DivMode, add, add_in_place, and, div, div_in_place, equal, greater, greater_or_equal,
less, less_or_equal, mod_op, mul, mul_in_place, or, pow, pow_in_place, sub, sub_in_place,
where_op, xor,
};
use crate::buffer_pool::BufferPool;
use crate::operator::{InputList, OpError, OpRunContext, Operator, OperatorExt};
use crate::value::Value;
// Verifies cycle/repeat counts for shapes where the fast broadcast path
// applies, and `None` where it does not.
#[test]
fn test_fast_broadcast_cycles_repeats() {
    // Empty/all-ones sources broadcast as pure repeats.
    let params = fast_broadcast_cycles_repeats(&[], &[1, 2, 3]);
    assert_eq!(params, Some((1, 6)));
    let params = fast_broadcast_cycles_repeats(&[1, 1, 1], &[5, 6, 2]);
    assert_eq!(params, Some((1, 60)));
    // Identical shapes need no duplication.
    let params = fast_broadcast_cycles_repeats(&[3, 4, 5], &[3, 4, 5]);
    assert_eq!(params, Some((1, 1)));
    // Leading broadcast dims become cycles.
    let params = fast_broadcast_cycles_repeats(&[1, 1, 10], &[5, 2, 10]);
    assert_eq!(params, Some((10, 1)));
    // Trailing broadcast dims become repeats.
    let params = fast_broadcast_cycles_repeats(&[10, 1, 1], &[10, 5, 6]);
    assert_eq!(params, Some((1, 30)));
    // Leading and trailing broadcast dims combined.
    let params = fast_broadcast_cycles_repeats(&[1, 10, 1], &[5, 10, 6]);
    assert_eq!(params, Some((5, 6)));
    // Broadcast dims in the middle are not supported by the fast path.
    let params = fast_broadcast_cycles_repeats(&[5, 1, 5], &[5, 6, 5]);
    assert_eq!(params, None);
    let params = fast_broadcast_cycles_repeats(&[1, 5, 1, 5, 1], &[2, 5, 6, 5, 2]);
    assert_eq!(params, None);
    // A shorter source is implicitly left-padded with 1s.
    let params = fast_broadcast_cycles_repeats(&[10], &[5, 3, 10]);
    assert_eq!(params, Some((15, 1)));
}

// A `from_shape` of higher rank than `to_shape` violates the function's
// precondition and should panic.
#[test]
#[should_panic]
fn test_fast_broadcast_cycles_repeats_invalid() {
    fast_broadcast_cycles_repeats(&[1, 2, 3], &[1, 2]);
}
// `add` on equal-shaped float and i32 tensors.
#[test]
fn test_add() -> Result<(), Box<dyn Error>> {
    let pool = BufferPool::new();
    let a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let b = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let expected = Tensor::from_data(&[2, 2], vec![11., 22., 33., 44.]);
    let result = add(&pool, a.view(), b.view()).unwrap();
    expect_equal(&result, &expected)?;
    let a = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let b = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
    let expected = Tensor::from_data(&[2, 2], vec![11, 22, 33, 44]);
    let result = add(&pool, a.view(), b.view()).unwrap();
    assert_eq!(result, expected);
    Ok(())
}

// `add` with various broadcast combinations: 1-element, 1D vs 2D, scalar and
// mutually-broadcast shapes.
#[test]
fn test_add_broadcasted() -> Result<(), Box<dyn Error>> {
    let pool = BufferPool::new();
    // Single-element tensor broadcast against a matrix, in both orders.
    let a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let b = Tensor::from([10.]);
    let expected = Tensor::from_data(&[2, 2], vec![11., 12., 13., 14.]);
    let result = add(&pool, a.view(), b.view()).unwrap();
    expect_equal(&result, &expected)?;
    let result = add(&pool, b.view(), a.view()).unwrap();
    expect_equal(&result, &expected)?;
    // 1D vector against a row matrix.
    let a = Tensor::from([1., 2., 3., 4., 5.]);
    let b = Tensor::from_data(&[1, 5], vec![1., 2., 3., 4., 5.]);
    let expected = Tensor::from_data(&[1, 5], vec![2., 4., 6., 8., 10.]);
    let result = add(&pool, a.view(), b.view()).unwrap();
    expect_equal(&result, &expected)?;
    // Scalar against a matrix.
    let a = Tensor::from(3.0);
    let b = Tensor::from_data(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
    let result = add(&pool, a.view(), b.view()).unwrap();
    let expected = Tensor::from_data(&[2, 2], vec![4.0, 5.0, 6.0, 7.0]);
    expect_equal(&result, &expected)?;
    // Column vector against row vector: both inputs are broadcast.
    let a = Tensor::from_data(&[2, 1], vec![1, 2]);
    let b = Tensor::from_data(&[1, 2], vec![3, 4]);
    let result = add(&pool, a.view(), b.view()).unwrap();
    assert_eq!(result.shape(), &[2, 2]);
    assert_eq!(result.to_vec(), &[4, 5, 5, 6]);
    Ok(())
}

// The output takes the broadcast shape even when the first input is the
// smaller one.
#[test]
fn test_add_broadcast_first_input() {
    let pool = BufferPool::new();
    let a: Tensor<i32> = Tensor::zeros(&[1, 1, 10]);
    let b = Tensor::zeros(&[1, 5, 10]);
    let result = add(&pool, a.view(), b.view()).unwrap();
    assert_eq!(result.shape(), &[1, 5, 10]);
}

// `add_in_place` and `Add::run_in_place`, including broadcast and
// non-contiguous destination cases.
#[test]
fn test_add_in_place() -> Result<(), Box<dyn Error>> {
    let pool = BufferPool::new();
    let mut a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let a_copy = a.clone();
    let b = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let expected = Tensor::from_data(&[2, 2], vec![11., 22., 33., 44.]);
    add_in_place(a.view_mut(), b.view());
    expect_equal(&a, &expected)?;
    let mut a_ints = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let b_ints = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
    let expected_ints = Tensor::from_data(&[2, 2], vec![11, 22, 33, 44]);
    add_in_place(a_ints.view_mut(), b_ints.view());
    assert_eq!(&a_ints, &expected_ints);
    // Via the operator interface.
    let op = Add {};
    let inputs: InputList = (&b).into();
    let ctx = OpRunContext::new(&pool, &inputs);
    let result = op.run_in_place(Value::FloatTensor(a_copy), &ctx).unwrap();
    expect_equal(&result.as_tensor_view().unwrap(), &expected.view())?;
    // Scalar first input cannot be updated in place; falls back to allocating.
    let scalar = Tensor::from(1.0);
    let expected = Tensor::from_data(&[2, 2], vec![11., 21., 31., 41.]);
    let result = op.run_in_place(Value::FloatTensor(scalar), &ctx).unwrap();
    expect_equal(&result.as_tensor_view().unwrap(), &expected.view())?;
    // Broadcast second operand.
    let mut a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let b = Tensor::from([1., 2.]);
    let expected = Tensor::from_data(&[2, 2], vec![2., 4., 4., 6.]);
    add_in_place(a.view_mut(), b.view());
    expect_equal(&a, &expected)?;
    // Non-contiguous destination exercises the indexed (slow) path.
    let mut a = Tensor::from_data(&[2, 3], vec![1., 2., 0., 3., 4., 0.]);
    a.clip_dim(1, 0..2);
    assert!(!a.is_contiguous());
    let b = Tensor::from([1., 2.]);
    let expected = Tensor::from_data(&[2, 2], vec![2., 4., 4., 6.]);
    add_in_place(a.view_mut(), b.view());
    expect_equal(&a, &expected)?;
    Ok(())
}

// Incompatible shapes produce an `IncompatibleInputShapes` error.
#[test]
fn test_add_invalid_broadcast() {
    let a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let b = Tensor::from_data(&[2, 3], vec![1., 2., 3., 4., 5., 6.]);
    let op = Add {};
    let result: Result<Tensor<f32>, _> = op.run_simple((&a, &b));
    assert_eq!(
        result.err(),
        Some(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))
    );
}
// Logical AND truth table over i32 inputs.
#[test]
fn test_and() {
    let pool = BufferPool::new();
    let a = Tensor::from([0, 1, 0, 1]);
    let b = Tensor::from([0, 0, 1, 1]);
    let expected = Tensor::from([0, 0, 0, 1]);
    let result = and(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
}

// `div` for float (exact) and i32 (truncating) inputs, including the
// scalar-divisor fast path.
#[test]
fn test_div() -> Result<(), Box<dyn Error>> {
    let pool = BufferPool::new();
    let a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let b = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let expected = Tensor::from_data(&[2, 2], vec![10., 10., 10., 10.]);
    let result = div(&pool, a.view(), b.view()).unwrap();
    expect_equal(&result, &expected)?;
    // Float scalar divisor: exercises the reciprocal-multiply path.
    let a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let b = Tensor::from(10.);
    let expected = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let result = div(&pool, a.view(), b.view()).unwrap();
    expect_equal(&result, &expected)?;
    // Integer division truncates.
    let a = Tensor::from([1, 2, 3, 4]);
    let b = Tensor::from([2, 2, 2, 2]);
    let expected = Tensor::from([0, 1, 1, 2]);
    let result = div(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
    // Integer scalar divisor must NOT use the reciprocal path.
    let a = Tensor::from([1, 2, 3, 4]);
    let b = Tensor::from(2);
    let expected = Tensor::from([0, 1, 1, 2]);
    let result = div(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
    Ok(())
}

// Same cases as `test_div`, via the in-place variant.
#[test]
fn test_div_in_place() -> Result<(), Box<dyn Error>> {
    let mut a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let b = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let expected = Tensor::from_data(&[2, 2], vec![10., 10., 10., 10.]);
    div_in_place(a.view_mut(), b.view());
    expect_equal(&a, &expected)?;
    let mut a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let b = Tensor::from(10.);
    let expected = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    div_in_place(a.view_mut(), b.view());
    expect_equal(&a, &expected)?;
    let mut a = Tensor::from([1, 2, 3, 4]);
    let b = Tensor::from([2, 2, 2, 2]);
    let expected = Tensor::from([0, 1, 1, 2]);
    div_in_place(a.view_mut(), b.view());
    assert_eq!(&a, &expected);
    let mut a = Tensor::from([1, 2, 3, 4]);
    let b = Tensor::from(2);
    let expected = Tensor::from([0, 1, 1, 2]);
    div_in_place(a.view_mut(), b.view());
    assert_eq!(&a, &expected);
    Ok(())
}
// Each comparison test checks both i32 and f32 element types.
#[test]
fn test_equal() {
    let pool = BufferPool::new();
    let a = Tensor::from([1, 2]);
    let b = Tensor::from([1, 3]);
    let expected = Tensor::from([1, 0]);
    let result = equal(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
    let a = Tensor::from([1., 2.]);
    let b = Tensor::from([1., 3.]);
    let expected = Tensor::from([1, 0]);
    let result = equal(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
}

#[test]
fn test_greater() {
    let pool = BufferPool::new();
    let a = Tensor::from([1, 2, 5]);
    let b = Tensor::from([1, 3, 4]);
    let expected = Tensor::from([0, 0, 1]);
    let result = greater(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
    let a = Tensor::from([1., 2., 5.]);
    let b = Tensor::from([1., 3., 4.]);
    let expected = Tensor::from([0, 0, 1]);
    let result = greater(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
}

#[test]
fn test_greater_or_equal() {
    let pool = BufferPool::new();
    let a = Tensor::from([1, 2, 5]);
    let b = Tensor::from([1, 3, 4]);
    let expected = Tensor::from([1, 0, 1]);
    let result = greater_or_equal(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
    let a = Tensor::from([1., 2., 5.]);
    let b = Tensor::from([1., 3., 4.]);
    let expected = Tensor::from([1, 0, 1]);
    let result = greater_or_equal(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
}

#[test]
fn test_less() {
    let pool = BufferPool::new();
    let a = Tensor::from([1, 2]);
    let b = Tensor::from([1, 3]);
    let expected = Tensor::from([0, 1]);
    let result = less(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
    let a = Tensor::from([1., 2.]);
    let b = Tensor::from([1., 3.]);
    let expected = Tensor::from([0, 1]);
    let result = less(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
}

#[test]
fn test_less_or_equal() {
    let pool = BufferPool::new();
    let a = Tensor::from([1, 2, 5]);
    let b = Tensor::from([1, 3, 4]);
    let expected = Tensor::from([1, 1, 0]);
    let result = less_or_equal(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
    let a = Tensor::from([1., 2., 5.]);
    let b = Tensor::from([1., 3., 4.]);
    let expected = Tensor::from([1, 1, 0]);
    let result = less_or_equal(&pool, a.view(), b.view()).unwrap();
    assert_eq!(&result, &expected);
}
#[test]
fn test_mod_op() {
    let pool = BufferPool::new();

    // Integer inputs covering every sign combination of dividend/divisor.
    let int_lhs = Tensor::from([10, -10, 10, -10]);
    let int_rhs = Tensor::from([3, 3, -3, -3]);

    // Floor division: remainder has the sign of the divisor.
    let got = mod_op(&pool, int_lhs.view(), int_rhs.view(), DivMode::FloorDiv).unwrap();
    assert_eq!(&got, &Tensor::from([1, 2, -2, -1]));

    // Truncated division: remainder has the sign of the dividend.
    let got = mod_op(&pool, int_lhs.view(), int_rhs.view(), DivMode::TruncDiv).unwrap();
    assert_eq!(&got, &Tensor::from([1, -1, 1, -1]));

    // Float inputs, same sign combinations.
    let float_lhs = Tensor::from([3.5, -3.5, 3.5, -3.5]);
    let float_rhs = Tensor::from([2.5, 2.5, -2.5, -2.5]);

    let got = mod_op(&pool, float_lhs.view(), float_rhs.view(), DivMode::FloorDiv).unwrap();
    assert_eq!(&got, &Tensor::from([1., 1.5, -1.5, -1.]));

    let got = mod_op(&pool, float_lhs.view(), float_rhs.view(), DivMode::TruncDiv).unwrap();
    assert_eq!(&got, &Tensor::from([1., -1., 1., -1.]));
}
#[test]
fn test_mul() -> Result<(), Box<dyn Error>> {
    let pool = BufferPool::new();

    // Float case.
    let lhs = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let rhs = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let want = Tensor::from_data(&[2, 2], vec![10., 40., 90., 160.]);
    expect_equal(&mul(&pool, lhs.view(), rhs.view()).unwrap(), &want)?;

    // Int case.
    let lhs = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let rhs = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
    let want = Tensor::from_data(&[2, 2], vec![10, 40, 90, 160]);
    assert_eq!(&mul(&pool, lhs.view(), rhs.view()).unwrap(), &want);

    Ok(())
}
#[test]
fn test_mul_in_place() -> Result<(), Box<dyn Error>> {
    // Float case: lhs is updated in place.
    let mut acc = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let factor = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    mul_in_place(acc.view_mut(), factor.view());
    expect_equal(&acc, &Tensor::from_data(&[2, 2], vec![10., 40., 90., 160.]))?;

    // Int case.
    let mut acc = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let factor = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
    mul_in_place(acc.view_mut(), factor.view());
    assert_eq!(&acc, &Tensor::from_data(&[2, 2], vec![10, 40, 90, 160]));

    Ok(())
}
#[test]
fn test_or() {
    let pool = BufferPool::new();

    // Full truth table of the logical-or op.
    let lhs = Tensor::from([0, 1, 0, 1]);
    let rhs = Tensor::from([0, 0, 1, 1]);
    let got = or(&pool, lhs.view(), rhs.view()).unwrap();
    assert_eq!(&got, &Tensor::from([0, 1, 1, 1]));
}
#[test]
fn test_pow() {
// Each case pairs base (`a`) and exponent (`b`) tensors with the expected
// result. Dynamically-typed `Value`s let one list mix float and int cases.
#[derive(Debug)]
struct Case {
a: Value,
b: Value,
expected: Value,
}
// Build a case from scalar base/exponent values. The expected value has the
// same element type as the base.
fn scalar_case<B, E>(base: B, exp: E, expected: B) -> Case
where
B: Clone + Scalar + Into<Tensor<B>>,
E: Clone + Scalar + Into<Tensor<E>>,
Value: From<Tensor<B>> + From<Tensor<E>>,
{
Case {
a: Tensor::from(base).into(),
b: Tensor::from(exp).into(),
expected: Tensor::from(expected).into(),
}
}
let cases = vec![
// Float vector base with a scalar exponent.
Case {
a: Tensor::from([2., 3., 4.]).into(),
b: Tensor::from(2.).into(),
expected: Tensor::from([4., 9., 16.]).into(),
},
Case {
a: Tensor::from([2., 3., 4.]).into(),
b: Tensor::from(3.).into(),
expected: Tensor::from([8., 27., 64.]).into(),
},
// Non-integral exponent; expected values computed with `f32::powf`.
Case {
a: Tensor::from([2., 3., 4.]).into(),
b: Tensor::from(0.256).into(),
expected: Tensor::from([
(2f32).powf(0.256),
(3f32).powf(0.256),
(4f32).powf(0.256),
])
.into(),
},
// Element-wise exponents (same shape as the base).
Case {
a: Tensor::from([2., 3., 4.]).into(),
b: Tensor::from([1., 2., 3.]).into(),
expected: Tensor::from([2., 9., 64.]).into(),
},
// Mixed types: i32 base raised to a float exponent yields an i32 result.
scalar_case(16i32, 0.5f32, 4i32),
// Small all-int cases.
scalar_case(4i32, 2i32, 16i32),
scalar_case(2i32, 3i32, 8i32),
// Larger results that still fit in i32.
scalar_case(4097i32, 2i32, 4097 * 4097),
scalar_case(i32::MAX.isqrt(), 2i32, i32::MAX.isqrt() * i32::MAX.isqrt()),
// Results that exceed i32::MAX: the expected values below assume the op
// wraps (two's-complement) rather than saturating or panicking.
{
let x = i32::MAX.isqrt() + 1;
scalar_case(x, 2i32, x.wrapping_mul(x))
},
// 216^4 = 2176782336, which wraps to a negative i32.
scalar_case(216i32, 4i32, -2118184960),
// Negative exponent with an int base: expected result is 0 (2^-1 = 0.5
// truncated to int).
scalar_case(2i32, -1i32, 0),
];
cases.test_each(|case| {
let pool = BufferPool::new();
// Run both the pool-allocating `pow` and the in-place `pow_in_place`
// variants against the same case.
macro_rules! test_case {
($a:expr, $b:expr, $expected:expr) => {
let result = pow(&pool, $a.view(), $b.view()).unwrap();
expect_equal(&result, $expected).unwrap();
let mut a = $a.clone();
pow_in_place(a.view_mut(), $b.view());
expect_equal(&a, $expected).unwrap();
};
}
// Dispatch on the concrete tensor types stored in the `Value`s. Only the
// (f32, f32), (i32, f32) and (i32, i32) combinations are exercised above.
match (&case.a, &case.b, &case.expected) {
(Value::FloatTensor(a), Value::FloatTensor(b), Value::FloatTensor(expected)) => {
test_case!(a, b, expected);
}
(Value::Int32Tensor(a), Value::FloatTensor(b), Value::Int32Tensor(expected)) => {
test_case!(a, b, expected);
}
(Value::Int32Tensor(a), Value::Int32Tensor(b), Value::Int32Tensor(expected)) => {
test_case!(a, b, expected);
}
_ => unimplemented!("unsupported types"),
}
})
}
#[test]
fn test_sub() -> Result<(), Box<dyn Error>> {
    let pool = BufferPool::new();

    // Float case.
    let lhs = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let rhs = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let want = Tensor::from_data(&[2, 2], vec![9., 18., 27., 36.]);
    expect_equal(&sub(&pool, lhs.view(), rhs.view()).unwrap(), &want)?;

    // Int case.
    let lhs = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
    let rhs = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    let want = Tensor::from_data(&[2, 2], vec![9, 18, 27, 36]);
    assert_eq!(&sub(&pool, lhs.view(), rhs.view()).unwrap(), &want);

    Ok(())
}
#[test]
fn test_sub_broadcast() -> Result<(), Box<dyn Error>> {
    let pool = BufferPool::new();

    // Subtract a (1, 3, 1)-shaped rhs from a (2, 3, 3)-shaped lhs; the rhs is
    // broadcast over the first and last dims.
    let lhs = Tensor::from([
        [[2, 4, 6], [8, 10, 12], [14, 16, 18]],
        [[3, 6, 9], [12, 15, 18], [21, 24, 27]],
    ]);
    let rhs = Tensor::from([[[1], [2], [3]]]);
    let want = Tensor::from([
        [[1, 3, 5], [6, 8, 10], [11, 13, 15]],
        [[2, 5, 8], [10, 13, 16], [18, 21, 24]],
    ]);
    expect_equal(&sub(&pool, lhs.view(), rhs.view()).unwrap(), &want)?;

    Ok(())
}
#[test]
fn test_sub_in_place() -> Result<(), Box<dyn Error>> {
    // Float case: lhs is updated in place.
    let mut acc = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let rhs = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    sub_in_place(acc.view_mut(), rhs.view());
    expect_equal(&acc, &Tensor::from_data(&[2, 2], vec![9., 18., 27., 36.]))?;

    // Int case.
    let mut acc = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
    let rhs = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
    sub_in_place(acc.view_mut(), rhs.view());
    assert_eq!(&acc, &Tensor::from_data(&[2, 2], vec![9, 18, 27, 36]));

    Ok(())
}
#[test]
fn test_where() {
    let pool = BufferPool::new();

    // All inputs have the same shape.
    let cond = Tensor::from_data(&[2, 2], vec![1, 0, 0, 1]);
    let x = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
    let y = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
    let got = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
    assert_eq!(&got, &Tensor::from_data(&[2, 2], vec![1., 20., 30., 4.]));

    // Scalar float x/y broadcast against a vector condition.
    let cond = Tensor::from([1, 1, 0, 0]);
    let x = Tensor::from(1.);
    let y = Tensor::from(2.);
    let got = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
    assert_eq!(&got, &Tensor::from([1., 1., 2., 2.]));

    // Scalar condition broadcast against vector x/y.
    let cond = Tensor::from(1);
    let x = Tensor::from([1., 2.]);
    let y = Tensor::from([3., 4.]);
    let got = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
    assert_eq!(&got, &Tensor::from([1., 2.]));

    // Scalar int x/y broadcast against a vector condition.
    let cond = Tensor::from([1, 1, 0, 0]);
    let x = Tensor::from(3);
    let y = Tensor::from(4);
    let got = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
    assert_eq!(&got, &Tensor::from([3, 3, 4, 4]));

    // x and y broadcast along the last dim of the condition.
    let cond = Tensor::from([[1, 0], [1, 0]]);
    let x = Tensor::from([[1], [2]]);
    let y = Tensor::from([[3], [4]]);
    let got = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
    assert_eq!(&got, &Tensor::from([[1, 3], [2, 4]]));
}
#[test]
fn test_where_invalid_inputs() {
    let pool = BufferPool::new();
    let cond = Tensor::from([1, 1]);
    let len_three = Tensor::from([1, 2, 3]);
    let len_two = Tensor::from([2, 2]);

    // A length-3 `x` cannot broadcast with the length-2 condition / `y`.
    let err = where_op(&pool, cond.view(), len_three.view(), len_two.view()).err();
    assert_eq!(
        err,
        Some(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))
    );

    // Same mismatch with `x` and `y` swapped.
    let err = where_op(&pool, cond.view(), len_two.view(), len_three.view()).err();
    assert_eq!(
        err,
        Some(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))
    );
}
#[test]
fn test_xor() {
    let pool = BufferPool::new();

    // Full truth table of the logical-xor op.
    let lhs = Tensor::from([0, 1, 0, 1]);
    let rhs = Tensor::from([0, 0, 1, 1]);
    let got = xor(&pool, lhs.view(), rhs.view()).unwrap();
    assert_eq!(&got, &Tensor::from([0, 1, 1, 0]));
}
}