use smallvec::SmallVec;
use std::fmt::Debug;
use std::iter::{repeat, zip};
use rten_tensor::prelude::*;
use rten_tensor::{Tensor, TensorView, TensorViewMut};
use crate::number::{AsBool, Identities, IsInt};
use crate::ops::{Input, InputList, IntoOpResult, OpError, Operator, Output, OutputList};
use crate::tensor_pool::TensorPool;
/// Compute the shape that results from broadcasting `a` and `b` together
/// under numpy-style rules, or `None` if the shapes are incompatible.
///
/// Dimensions are matched from the trailing end; a size-1 dimension in one
/// shape broadcasts to the corresponding size in the other.
pub fn broadcast_shapes(a: &[usize], b: &[usize]) -> Option<SmallVec<[usize; 4]>> {
    let out_ndim = a.len().max(b.len());
    let mut result = SmallVec::with_capacity(out_ndim);

    // Walk dimensions right-to-left. A shape shorter than `out_ndim` is
    // implicitly left-padded with 1s.
    for i in 0..out_ndim {
        let a_size = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let b_size = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        let out_size = match (a_size, b_size) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None,
        };
        result.push(out_size);
    }

    // Sizes were collected back-to-front; restore outermost-first order.
    result.reverse();
    Some(result)
}
/// Return true if a binary operation can write its result directly into
/// `a`'s buffer. This holds when `b` broadcasts to `a`'s shape, since the
/// output shape then equals `a`'s shape.
fn can_run_binary_op_in_place<L1: Layout, L2: Layout>(a: &L1, b: &L2) -> bool {
    let out_shape = a.shape();
    b.can_broadcast_to(out_shape.as_ref())
}
/// Test whether broadcasting `from_shape` to `to_shape` can be performed by
/// cycling and repeating the source data.
///
/// Returns `Some((cycles, repeats))` if the broadcast output is produced by
/// repeating each source element `repeats` times and cycling through the
/// whole sequence `cycles` times, or `None` if a more general broadcast
/// strategy is required.
///
/// Panics if `to_shape` has fewer dimensions than `from_shape`.
pub fn fast_broadcast_cycles_repeats(
    from_shape: &[usize],
    to_shape: &[usize],
) -> Option<(usize, usize)> {
    // Identical shapes: the data is used exactly once, unmodified.
    if from_shape == to_shape {
        return Some((1, 1));
    }
    // A single-element source is simply repeated to fill the output.
    if from_shape.iter().product::<usize>() == 1 {
        return Some((1, to_shape.iter().product()));
    }
    assert!(to_shape.len() >= from_shape.len());

    // Treat `from_shape` as left-padded with 1s to the length of `to_shape`.
    let pad = to_shape.len() - from_shape.len();
    let from_size = |dim: usize| if dim < pad { 1 } else { from_shape[dim - pad] };

    // Leading run of broadcast (size-1 source) dimensions: their combined
    // output size is the number of cycles. A zero-sized output dim ends the
    // run and is rejected by the middle-section check below.
    let mut prefix_len = 0;
    let mut cycles = 1;
    for (i, &to) in to_shape.iter().enumerate() {
        if from_size(i) != 1 || to == 0 {
            break;
        }
        prefix_len += 1;
        cycles *= to;
    }

    // Trailing run of broadcast dimensions: their combined output size is
    // the per-element repeat count.
    let mut suffix_len = 0;
    let mut repeats = 1;
    for (i, &to) in to_shape.iter().enumerate().rev() {
        if from_size(i) != 1 || to == 0 {
            break;
        }
        suffix_len += 1;
        repeats *= to;
    }

    // Every dimension between the two runs must match exactly; otherwise
    // broadcasting interleaves the data and plain cycling cannot express it.
    for i in prefix_len..to_shape.len() - suffix_len {
        if from_size(i) != to_shape[i] {
            return None;
        }
    }
    Some((cycles, repeats))
}
/// Simplified variant of [`fast_broadcast_cycles_repeats`] that only
/// succeeds when the broadcast is achieved purely by cycling the source
/// data (ie. a repeat count of 1). Returns the cycle count.
fn fast_broadcast_cycles(from_shape: &[usize], to_shape: &[usize]) -> Option<usize> {
    // A single-element source cycles once per output element. This case is
    // handled here because the underlying helper reports it as repeats.
    if from_shape.iter().product::<usize>() == 1 {
        return Some(to_shape.iter().product());
    }
    fast_broadcast_cycles_repeats(from_shape, to_shape)
        .and_then(|(cycles, repeats)| if repeats == 1 { Some(cycles) } else { None })
}
/// Apply the elementwise operation `op` to each pair of elements from `a`
/// and `b` after broadcasting them to a common shape, returning a new
/// tensor of the broadcast output shape.
///
/// Returns `OpError::IncompatibleInputShapes` if the shapes cannot be
/// broadcast together.
pub fn binary_op<T: Copy, R, F: Fn(T, T) -> R>(
    pool: &TensorPool,
    a: TensorView<T>,
    b: TensorView<T>,
    op: F,
) -> Result<Tensor<R>, OpError> {
    let out_shape = broadcast_shapes(a.shape(), b.shape())
        .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))?;
    // Fast path: `a` already has the output shape, both inputs expose
    // contiguous data (`data()` is Some) and `b` broadcasts to `a`'s shape
    // by simple cycling/repetition of its buffer.
    if let (true, Some(a_data), Some(b_data)) =
        (a.shape() == out_shape.as_slice(), a.data(), b.data())
    {
        if let Some((cycles, repeats)) = fast_broadcast_cycles_repeats(b.shape(), a.shape()) {
            // This equality justifies the unchecked accesses below: the
            // loops advance `i` exactly `cycles * b_data.len() * repeats`
            // times, which must equal the output length.
            assert!(cycles * b_data.len() * repeats == a.len());
            let mut output = Tensor::uninit_in(pool, &out_shape);
            let out_data = output.data_mut().unwrap();
            let a_ptr = a_data.as_ptr();
            let mut i = 0;
            for _ in 0..cycles {
                if repeats == 1 {
                    // Common case: pair each `b` element with one `a`/output
                    // element. Hoisted out of the inner loop for speed.
                    for b_elt in b_data {
                        // SAFETY: `i < a.len() == out_data.len()` per the
                        // assertion above.
                        let (a_elt, out_elt) =
                            unsafe { (*a_ptr.add(i), out_data.get_unchecked_mut(i)) };
                        out_elt.write(op(a_elt, *b_elt));
                        i += 1;
                    }
                } else {
                    for b_elt in b_data {
                        for _ in 0..repeats {
                            // SAFETY: as above, `i` stays below `a.len()`.
                            let (a_elt, out_elt) =
                                unsafe { (*a_ptr.add(i), out_data.get_unchecked_mut(i)) };
                            out_elt.write(op(a_elt, *b_elt));
                            i += 1;
                        }
                    }
                }
            }
            // Every output element must have been written before
            // `assume_init` is sound.
            assert!(i == output.len());
            let output = unsafe { output.assume_init() };
            return Ok(output);
        }
    }
    // General path: broadcast both inputs to the output shape and iterate.
    let mut a = a.broadcast(out_shape.as_slice());
    let mut b = b.broadcast(out_shape.as_slice());
    let mut out_data = pool.alloc(a.len());
    let out_uninit = &mut out_data.spare_capacity_mut()[..a.len()];
    let mut out_offset = 0;
    // Pad with leading 1-sized axes so `inner_iter::<4>` covers all
    // elements; the fixed-rank inner views allow unchecked indexing below.
    while a.ndim() <= 4 {
        a.insert_axis(0);
        b.insert_axis(0);
    }
    a.inner_iter::<4>()
        .zip(b.inner_iter::<4>())
        .for_each(|(a, b)| {
            for i0 in 0..a.size(0) {
                for i1 in 0..a.size(1) {
                    for i2 in 0..a.size(2) {
                        for i3 in 0..a.size(3) {
                            // SAFETY: indices are within each view's bounds
                            // and exactly one output element is written per
                            // input element, so `out_offset < a.len()`.
                            unsafe {
                                let a_elt = a.get_unchecked([i0, i1, i2, i3]);
                                let b_elt = b.get_unchecked([i0, i1, i2, i3]);
                                out_uninit
                                    .get_unchecked_mut(out_offset)
                                    .write(op(*a_elt, *b_elt));
                                out_offset += 1;
                            }
                        }
                    }
                }
            }
        });
    // SAFETY: the first `out_offset` elements of `out_data` were
    // initialized by the loop above.
    unsafe {
        out_data.set_len(out_offset);
    }
    Ok(Tensor::from_data(&out_shape, out_data))
}
/// Apply the elementwise operation `op` to pairs of elements from `a` and
/// `b`, writing the result back into `a`.
///
/// Callers must ensure `b` broadcasts to `a`'s shape (see
/// `can_run_binary_op_in_place`); the fallback path's `broadcast` call
/// relies on it.
fn binary_op_in_place<T: Copy + Debug, F: Fn(T, T) -> T>(
    mut a: TensorViewMut<T>,
    b: TensorView<T>,
    op: F,
) {
    // Fast path: `a` is contiguous, `b` is contiguous and broadcasts to
    // `a`'s shape by simple cycling/repetition of its buffer.
    if let (true, Some(b_data)) = (a.is_contiguous(), b.data()) {
        if let Some((cycles, repeats)) = fast_broadcast_cycles_repeats(b.shape(), a.shape()) {
            // Justifies the unchecked indexing below: `i` is advanced
            // exactly this many times.
            assert!(cycles * b_data.len() * repeats == a.len());
            let a_data = a.data_mut().unwrap();
            let mut i = 0;
            for _ in 0..cycles {
                if repeats == 1 {
                    // Common case hoisted out of the inner loop.
                    for b_elt in b_data {
                        // SAFETY: `i < a.len()` per the assertion above.
                        let a_elt = unsafe { a_data.get_unchecked_mut(i) };
                        *a_elt = op(*a_elt, *b_elt);
                        i += 1;
                    }
                } else {
                    for b_elt in b_data {
                        for _ in 0..repeats {
                            // SAFETY: as above.
                            let a_elt = unsafe { a_data.get_unchecked_mut(i) };
                            *a_elt = op(*a_elt, *b_elt);
                            i += 1;
                        }
                    }
                }
            }
            return;
        }
    }
    // General path: broadcast `b` to `a`'s shape and update elementwise.
    let mut b = b.broadcast(a.shape());
    // Pad with leading 1-sized axes so `inner_iter::<4>` covers all
    // elements via fixed-rank views, enabling unchecked indexing.
    while a.ndim() <= 4 {
        a.insert_axis(0);
        b.insert_axis(0);
    }
    a.inner_iter_mut::<4>()
        .zip(b.inner_iter::<4>())
        .for_each(|(mut a, b)| {
            for i0 in 0..a.size(0) {
                for i1 in 0..a.size(1) {
                    for i2 in 0..a.size(2) {
                        for i3 in 0..a.size(3) {
                            // SAFETY: indices are within the bounds of the
                            // equal-shaped inner views.
                            unsafe {
                                let a_elt = a.get_unchecked_mut([i0, i1, i2, i3]);
                                let b_elt = b.get_unchecked([i0, i1, i2, i3]);
                                *a_elt = op(*a_elt, *b_elt);
                            }
                        }
                    }
                }
            }
        });
}
/// Perform an elementwise binary operation whose `op` is commutative.
///
/// The operands are ordered so the larger input comes first, which lets
/// `binary_op` take its fast path (the first input must already have the
/// output shape) in more cases.
fn binary_commutative_op<T: Copy + Debug + Default, F: Fn(T, T) -> T>(
    pool: &TensorPool,
    a: TensorView<T>,
    b: TensorView<T>,
    op: F,
) -> Result<Tensor<T>, OpError> {
    let (first, second) = if b.len() > a.len() { (b, a) } else { (a, b) };
    binary_op(pool, first, second, op)
}
/// Extract two input tensors from `$inputs` and run `$op_func` on them,
/// dispatching on the dtype of the first input.
///
/// The second input must have the same dtype as the first; `require_as`
/// returns an error otherwise.
macro_rules! run_typed_op {
    ($pool:expr, $inputs:expr, $op_func:ident) => {{
        let a = $inputs.require(0)?;
        match a {
            Input::FloatTensor(a) => {
                let b = $inputs.require_as::<f32>(1)?;
                $op_func($pool, a, b).into_op_result()
            }
            Input::IntTensor(a) => {
                let b = $inputs.require_as::<i32>(1)?;
                $op_func($pool, a, b).into_op_result()
            }
        }
    }};
    // Convenience variant that allocates from a fresh, throwaway pool.
    ($inputs:expr, $op_func:ident) => {
        run_typed_op!(&TensorPool::new(), $inputs, $op_func)
    };
}
/// Run an in-place binary operation, dispatching on the dtype of the owned
/// `$input` tensor.
///
/// Uses `$in_place_op_func` (writing into `$input`'s buffer) when the
/// second operand broadcasts to `$input`'s shape; otherwise falls back to
/// the allocating `$op_func`.
macro_rules! run_typed_op_in_place {
    ($pool:expr, $input:expr, $other: expr, $in_place_op_func:ident, $op_func:ident) => {{
        match $input {
            Output::FloatTensor(mut a) => {
                let b = $other.require_as::<f32>(0)?;
                if can_run_binary_op_in_place(&a, &b) {
                    // NOTE(review): this arm passes `b` directly while the
                    // int arm passes `b.view()`; both appear equivalent
                    // here — confirm against TensorView's API.
                    $in_place_op_func(a.view_mut(), b);
                    Ok(a.into())
                } else {
                    $op_func($pool, a.view(), b.view()).map(|t| t.into())
                }
            }
            Output::IntTensor(mut a) => {
                let b = $other.require_as::<i32>(0)?;
                if can_run_binary_op_in_place(&a, &b) {
                    $in_place_op_func(a.view_mut(), b.view());
                    Ok(a.into())
                } else {
                    $op_func($pool, a.view(), b.view()).map(|t| t.into())
                }
            }
        }
    }};
}
/// Perform elementwise addition of two tensors with broadcasting.
pub fn add<T: Copy + Debug + Default + std::ops::Add<Output = T>>(
    pool: &TensorPool,
    a: TensorView<T>,
    b: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    // Addition is commutative, so operands may be swapped internally to
    // enable the fast path.
    binary_commutative_op(pool, a, b, |lhs, rhs| lhs + rhs)
}
/// In-place version of [`add`]. `b` must broadcast to `a`'s shape.
pub fn add_in_place<T: Copy + Debug + std::ops::Add<Output = T>>(
    a: TensorViewMut<T>,
    b: TensorView<T>,
) {
    binary_op_in_place(a, b, |lhs, rhs| lhs + rhs);
}
/// Operator computing elementwise addition with broadcasting.
#[derive(Debug)]
pub struct Add {}

impl Operator for Add {
    fn name(&self) -> &str {
        "Add"
    }

    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        run_typed_op!(pool, inputs, add)
    }

    // Addition can reuse the first input's buffer when the second operand
    // broadcasts to its shape.
    fn can_run_in_place(&self) -> bool {
        true
    }

    fn is_commutative(&self) -> bool {
        true
    }

    fn run_in_place(
        &self,
        pool: &TensorPool,
        input: Output,
        other: InputList,
    ) -> Result<Output, OpError> {
        run_typed_op_in_place!(pool, input, other, add_in_place, add)
    }
}
/// Define a logical boolean operator: a function `$op_fn` that applies the
/// boolean expression `$expr` elementwise (operands converted via
/// `AsBool`, result as 0/1 i32), plus an `Operator` struct `$op` whose
/// `run` requires two i32 tensor inputs.
macro_rules! logical_boolean_op {
    ($op:ident, $op_fn:ident, $expr:expr) => {
        pub fn $op_fn<T: AsBool + Copy + Debug>(
            pool: &TensorPool,
            a: TensorView<T>,
            b: TensorView<T>,
        ) -> Result<Tensor<i32>, OpError> {
            #[allow(clippy::redundant_closure_call)]
            binary_op(pool, a, b, |x, y| $expr(x.as_bool(), y.as_bool()).into())
        }

        #[derive(Debug)]
        pub struct $op {}

        impl Operator for $op {
            fn name(&self) -> &str {
                stringify!($op)
            }

            // And/Or/Xor are all symmetric in their operands.
            fn is_commutative(&self) -> bool {
                true
            }

            fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
                let a: TensorView<i32> = inputs.require_as(0)?;
                let b: TensorView<i32> = inputs.require_as(1)?;
                $op_fn(pool, a, b).into_op_result()
            }
        }
    };
}
// Logical boolean operators. Operands are converted to bool via `AsBool`
// before the expression is applied; the result tensor holds 0/1 i32 values.
logical_boolean_op!(And, and, |x, y| x && y);
logical_boolean_op!(Or, or, |x, y| x || y);
logical_boolean_op!(Xor, xor, |x, y| x ^ y);
/// Perform elementwise division of two tensors with broadcasting.
///
/// Dividing a float tensor by a scalar is rewritten as multiplication by
/// the reciprocal, which is faster; results may differ slightly from exact
/// division due to rounding. Integer division always divides directly.
pub fn div<
    T: Copy
        + Debug
        + Default
        + std::ops::Mul<Output = T>
        + std::ops::Div<Output = T>
        + IsInt
        + Identities,
>(
    pool: &TensorPool,
    a: TensorView<T>,
    b: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    if !T::is_int() {
        if let Some(scalar) = b.item() {
            return mul(pool, a, Tensor::from_scalar(T::one() / *scalar).view());
        }
    }
    binary_op(pool, a, b, |x, y| x / y)
}
/// In-place version of [`div`]. `b` must broadcast to `a`'s shape.
///
/// Like [`div`], float division by a scalar is performed as multiplication
/// by the reciprocal.
pub fn div_in_place<
    T: Copy + Debug + std::ops::Mul<Output = T> + std::ops::Div<Output = T> + IsInt + Identities,
>(
    a: TensorViewMut<T>,
    b: TensorView<T>,
) {
    if !T::is_int() {
        if let Some(scalar) = b.item() {
            mul_in_place(a, Tensor::from_scalar(T::one() / *scalar).view());
            return;
        }
    }
    binary_op_in_place(a, b, |x, y| x / y);
}
/// Operator computing elementwise division with broadcasting.
#[derive(Debug)]
pub struct Div {}

impl Operator for Div {
    fn name(&self) -> &str {
        "Div"
    }

    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        run_typed_op!(pool, inputs, div)
    }

    // Division can reuse the first input's buffer when the second operand
    // broadcasts to its shape. Note: no `is_commutative` — division is not.
    fn can_run_in_place(&self) -> bool {
        true
    }

    fn run_in_place(
        &self,
        pool: &TensorPool,
        input: Output,
        other: InputList,
    ) -> Result<Output, OpError> {
        run_typed_op_in_place!(pool, input, other, div_in_place, div)
    }
}
/// Comparison to apply in `boolean_op`. Each variant maps to the
/// corresponding `PartialEq`/`PartialOrd` operator.
enum BooleanOp {
    Equal,
    Less,
    LessOrEqual,
    Greater,
    GreaterOrEqual,
}
/// Compare `a` and `b` elementwise with the comparison selected by `op`,
/// producing a tensor of 0/1 i32 values.
fn boolean_op<T: Copy + Debug + PartialEq + PartialOrd>(
    pool: &TensorPool,
    a: TensorView<T>,
    b: TensorView<T>,
    op: BooleanOp,
) -> Result<Tensor<i32>, OpError> {
    binary_op(pool, a, b, |x, y| {
        let result = match op {
            BooleanOp::Equal => x == y,
            BooleanOp::Less => x < y,
            BooleanOp::LessOrEqual => x <= y,
            BooleanOp::Greater => x > y,
            BooleanOp::GreaterOrEqual => x >= y,
        };
        result as i32
    })
}
/// Define a comparison operator: a function `$func` that applies the
/// `BooleanOp::$name` comparison elementwise, plus an `Operator` struct
/// `$name` dispatching on input dtype.
macro_rules! boolean_cmp_op {
    ($name:ident, $func:ident) => {
        pub fn $func<T: Copy + Debug + PartialEq + PartialOrd>(
            pool: &TensorPool,
            a: TensorView<T>,
            b: TensorView<T>,
        ) -> Result<Tensor<i32>, OpError> {
            boolean_op(pool, a, b, BooleanOp::$name)
        }

        #[derive(Debug)]
        pub struct $name {}

        impl Operator for $name {
            fn name(&self) -> &str {
                stringify!($name)
            }

            // Only equality is symmetric; the orderings are not.
            fn is_commutative(&self) -> bool {
                stringify!($name) == "Equal"
            }

            fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
                run_typed_op!(pool, inputs, $func)
            }
        }
    };
}
// Elementwise comparison operators. Each produces a 0/1 i32 tensor.
boolean_cmp_op!(Equal, equal);
boolean_cmp_op!(Greater, greater);
boolean_cmp_op!(GreaterOrEqual, greater_or_equal);
boolean_cmp_op!(Less, less);
boolean_cmp_op!(LessOrEqual, less_or_equal);
/// Compute `x % y` using the sign convention of Python's `%` operator (the
/// result takes the sign of `y`), rather than Rust's truncating `%` (result
/// takes the sign of `x`).
fn rem_floor<
    T: Copy + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Rem<Output = T>,
>(
    x: T,
    y: T,
) -> T {
    let zero = T::default();
    let rem = x % y;
    // If the truncated remainder is non-zero and its sign differs from
    // `y`'s, shift it by one period so it takes `y`'s sign.
    let wrong_sign = (rem > zero && y < zero) || (rem < zero && y > zero);
    if wrong_sign {
        rem + y
    } else {
        rem
    }
}
/// Sign convention for the remainder computed by [`mod_op`].
pub enum DivMode {
    // Remainder takes the sign of the divisor (Python's `%`).
    FloorDiv,
    // Remainder takes the sign of the dividend (Rust's `%`, C's `fmod`).
    TruncDiv,
}
pub fn mod_op<
T: Copy + Debug + Default + PartialOrd + std::ops::Add<Output = T> + std::ops::Rem<Output = T>,
>(
pool: &TensorPool,
a: TensorView<T>,
b: TensorView<T>,
mode: DivMode,
) -> Result<Tensor<T>, OpError> {
binary_op(
pool,
a,
b,
match mode {
DivMode::FloorDiv => |x, y| rem_floor(x, y),
DivMode::TruncDiv => |x, y| x % y,
},
)
}
/// Operator computing the elementwise remainder with broadcasting.
#[derive(Debug)]
pub struct Mod {
    // If true, use truncated division (C `fmod` semantics); otherwise use
    // floored division (Python `%` semantics).
    pub fmod: bool,
}

impl Operator for Mod {
    fn name(&self) -> &str {
        "Mod"
    }

    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        let a = inputs.require(0)?;
        let mode = if self.fmod {
            DivMode::TruncDiv
        } else {
            DivMode::FloorDiv
        };
        // Dispatch on the dtype of the first input; the second must match.
        match a {
            Input::FloatTensor(a) => {
                let b = inputs.require_as::<f32>(1)?;
                mod_op(pool, a, b, mode).into_op_result()
            }
            Input::IntTensor(a) => {
                let b = inputs.require_as::<i32>(1)?;
                mod_op(pool, a, b, mode).into_op_result()
            }
        }
    }
}
/// Perform elementwise multiplication of two tensors with broadcasting.
pub fn mul<T: Copy + Debug + Default + std::ops::Mul<Output = T>>(
    pool: &TensorPool,
    a: TensorView<T>,
    b: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    // Multiplication is commutative, so operands may be swapped internally
    // to enable the fast path.
    binary_commutative_op(pool, a, b, |lhs, rhs| lhs * rhs)
}
/// In-place version of [`mul`]. `b` must broadcast to `a`'s shape.
pub fn mul_in_place<T: Copy + Debug + std::ops::Mul<Output = T>>(
    a: TensorViewMut<T>,
    b: TensorView<T>,
) {
    binary_op_in_place(a, b, |lhs, rhs| lhs * rhs);
}
/// Operator computing elementwise multiplication with broadcasting.
#[derive(Debug)]
pub struct Mul {}

impl Operator for Mul {
    fn name(&self) -> &str {
        "Mul"
    }

    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        run_typed_op!(pool, inputs, mul)
    }

    // Multiplication can reuse the first input's buffer when the second
    // operand broadcasts to its shape.
    fn can_run_in_place(&self) -> bool {
        true
    }

    fn is_commutative(&self) -> bool {
        true
    }

    fn run_in_place(
        &self,
        pool: &TensorPool,
        input: Output,
        other: InputList,
    ) -> Result<Output, OpError> {
        run_typed_op_in_place!(pool, input, other, mul_in_place, mul)
    }
}
/// Raise `x` to the power `y`.
///
/// Squaring and cubing are special-cased as plain multiplications, which
/// are much cheaper than the general `powf`.
fn powf(x: f32, y: f32) -> f32 {
    if y == 2. {
        return x * x;
    }
    if y == 3. {
        return x * x * x;
    }
    x.powf(y)
}
/// Raise elements of `a` to powers taken from `b`, with broadcasting.
///
/// A scalar exponent uses a cheaper single-pass map over `a`.
pub fn pow(pool: &TensorPool, a: TensorView, b: TensorView) -> Result<Tensor, OpError> {
    match b.item() {
        Some(&exp) => Ok(a.map_in(pool, |x| powf(*x, exp))),
        None => binary_op(pool, a, b, powf),
    }
}
/// In-place version of [`pow`]. `b` must broadcast to `a`'s shape.
pub fn pow_in_place(mut a: TensorViewMut, b: TensorView) {
    match b.item() {
        // Scalar exponent: update `a` with a single pass.
        Some(&exp) => a.apply(|x| powf(*x, exp)),
        None => binary_op_in_place(a, b, powf),
    }
}
/// Operator computing elementwise exponentiation with broadcasting.
#[derive(Debug)]
pub struct Pow {}

impl Operator for Pow {
    fn name(&self) -> &str {
        "Pow"
    }

    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        let a = inputs.require_as(0)?;
        let b = inputs.require_as(1)?;
        pow(pool, a, b).into_op_result()
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    // Unlike Add/Mul/Div this does not use `run_typed_op_in_place!`
    // because `pow` is float-only.
    fn run_in_place(
        &self,
        pool: &TensorPool,
        input: Output,
        other: InputList,
    ) -> Result<Output, OpError> {
        let mut a = input.into_float().ok_or(OpError::IncorrectInputType)?;
        let b = other.require_as(0)?;
        if can_run_binary_op_in_place(&a, &b) {
            pow_in_place(a.view_mut(), b);
            Ok(a.into())
        } else {
            pow(pool, a.view(), b).map(|t| t.into())
        }
    }
}
/// Perform elementwise subtraction of two tensors with broadcasting.
pub fn sub<T: Copy + Debug + Default + std::ops::Sub<Output = T>>(
    pool: &TensorPool,
    a: TensorView<T>,
    b: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    // Subtraction is not commutative, so operands are never swapped.
    binary_op(pool, a, b, |lhs, rhs| lhs - rhs)
}
/// In-place version of [`sub`]. `b` must broadcast to `a`'s shape.
pub fn sub_in_place<T: Copy + Debug + std::ops::Sub<Output = T>>(
    a: TensorViewMut<T>,
    b: TensorView<T>,
) {
    binary_op_in_place(a, b, |lhs, rhs| lhs - rhs);
}
/// Operator computing elementwise subtraction with broadcasting.
#[derive(Debug)]
pub struct Sub {}

impl Operator for Sub {
    fn name(&self) -> &str {
        "Sub"
    }

    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        run_typed_op!(pool, inputs, sub)
    }

    // Subtraction can reuse the first input's buffer when the second
    // operand broadcasts to its shape. Not commutative.
    fn can_run_in_place(&self) -> bool {
        true
    }

    fn run_in_place(
        &self,
        pool: &TensorPool,
        input: Output,
        other: InputList,
    ) -> Result<Output, OpError> {
        run_typed_op_in_place!(pool, input, other, sub_in_place, sub)
    }
}
/// Select elements from `x` or `y` according to `cond`.
///
/// All three inputs are broadcast to a common result shape; for each
/// output position the value comes from `x` where `cond` is non-zero and
/// from `y` otherwise. Returns an error if the shapes cannot be broadcast
/// together.
pub fn where_op<T: Copy>(
    pool: &TensorPool,
    cond: TensorView<i32>,
    x: TensorView<T>,
    y: TensorView<T>,
) -> Result<Tensor<T>, OpError> {
    // The result shape is the three-way broadcast of all input shapes.
    let broadcast_xy_shape = broadcast_shapes(x.shape(), y.shape())
        .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))?;
    let result_shape = broadcast_shapes(cond.shape(), &broadcast_xy_shape)
        .ok_or(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))?;
    // Fast path applies only when each input broadcasts to the result
    // purely by cycling its data (no per-element repeats).
    let cond_cycles = fast_broadcast_cycles(cond.shape(), &result_shape);
    let x_cycles = fast_broadcast_cycles(x.shape(), &result_shape);
    let y_cycles = fast_broadcast_cycles(y.shape(), &result_shape);
    let can_cycle = cond_cycles.is_some() && x_cycles.is_some() && y_cycles.is_some();
    let out_len = result_shape.iter().product();
    let mut out_data = pool.alloc(out_len);
    // Fast path: all inputs contiguous and cycle-broadcastable, so the
    // output is produced by zipping cycling iterators over the raw buffers.
    if let (true, Some(cond_data), Some(x_data), Some(y_data)) =
        (can_cycle, cond.data(), x.data(), y.data())
    {
        out_data.extend(
            cond_data
                .iter()
                .cycle()
                .zip(x_data.iter().cycle())
                .zip(y_data.iter().cycle())
                .take(out_len)
                .map(|((&cond, &x), &y)| if cond != 0 { x } else { y }),
        );
    } else {
        // General path: broadcast all inputs to the result shape and
        // iterate elementwise.
        let mut cond = cond.broadcast(result_shape.as_slice());
        let mut x = x.broadcast(result_shape.as_slice());
        let mut y = y.broadcast(result_shape.as_slice());
        // Pad with leading 1-sized axes so `inner_iter::<4>` covers all
        // elements via fixed-rank views, enabling unchecked indexing.
        while cond.ndim() <= 4 {
            cond.insert_axis(0);
            x.insert_axis(0);
            y.insert_axis(0);
        }
        let out_uninit = &mut out_data.spare_capacity_mut()[..cond.len()];
        let mut out_offset = 0;
        cond.inner_iter::<4>()
            .zip(x.inner_iter::<4>().zip(y.inner_iter::<4>()))
            .for_each(|(cond, (x, y))| {
                for i0 in 0..cond.size(0) {
                    for i1 in 0..cond.size(1) {
                        for i2 in 0..cond.size(2) {
                            for i3 in 0..cond.size(3) {
                                // SAFETY: indices are within the bounds of
                                // the equal-shaped views, and one output
                                // element is written per input element.
                                unsafe {
                                    let cond_elt = *cond.get_unchecked([i0, i1, i2, i3]);
                                    let x_elt = *x.get_unchecked([i0, i1, i2, i3]);
                                    let y_elt = *y.get_unchecked([i0, i1, i2, i3]);
                                    let out_elt = if cond_elt != 0 { x_elt } else { y_elt };
                                    out_uninit.get_unchecked_mut(out_offset).write(out_elt);
                                    out_offset += 1;
                                }
                            }
                        }
                    }
                }
            });
        assert!(out_offset == cond.len());
        // SAFETY: the first `cond.len()` elements were initialized above.
        unsafe {
            out_data.set_len(cond.len());
        }
    }
    Ok(Tensor::from_data(&result_shape, out_data))
}
/// Operator selecting elements from two tensors according to a condition
/// tensor (see [`where_op`]).
#[derive(Debug)]
pub struct Where {}

impl Operator for Where {
    fn name(&self) -> &str {
        "Where"
    }

    fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
        // Condition is always i32; the dtype of `x` determines the output
        // dtype and `y` must match it.
        let condition = inputs.require_as::<i32>(0)?;
        let x = inputs.require(1)?;
        let y = inputs.require(2)?;
        match x {
            Input::FloatTensor(x) => {
                let y: TensorView = y.try_into()?;
                where_op(pool, condition, x, y).into_op_result()
            }
            Input::IntTensor(x) => {
                let y: TensorView<i32> = y.try_into()?;
                where_op(pool, condition, x, y).into_op_result()
            }
        }
    }
}
// Unit tests for the elementwise binary operators. Most tests check each
// op on both float and int inputs, plus broadcasting and in-place variants.
#[cfg(test)]
mod tests {
    use std::error::Error;

    use rten_tensor::prelude::*;
    use rten_tensor::test_util::expect_equal;
    use rten_tensor::Tensor;

    use super::{fast_broadcast_cycles, fast_broadcast_cycles_repeats};
    use crate::ops::tests::new_pool;
    use crate::ops::{
        add, add_in_place, and, div, div_in_place, equal, greater, greater_or_equal, less,
        less_or_equal, mod_op, mul, mul_in_place, or, pow, pow_in_place, sub, sub_in_place,
        where_op, xor, Add, DivMode, OpError, Operator, Output,
    };

    #[test]
    fn test_fast_broadcast_cycles_repeats() {
        // Scalar and all-ones sources are pure repeats.
        let params = fast_broadcast_cycles_repeats(&[], &[1, 2, 3]);
        assert_eq!(params, Some((1, 6)));
        let params = fast_broadcast_cycles_repeats(&[1, 1, 1], &[5, 6, 2]);
        assert_eq!(params, Some((1, 60)));
        // Same shape: no cycling or repeating needed.
        let params = fast_broadcast_cycles_repeats(&[3, 4, 5], &[3, 4, 5]);
        assert_eq!(params, Some((1, 1)));
        // Leading broadcast dims produce cycles, trailing ones repeats.
        let params = fast_broadcast_cycles_repeats(&[1, 1, 10], &[5, 2, 10]);
        assert_eq!(params, Some((10, 1)));
        let params = fast_broadcast_cycles_repeats(&[10, 1, 1], &[10, 5, 6]);
        assert_eq!(params, Some((1, 30)));
        let params = fast_broadcast_cycles_repeats(&[1, 10, 1], &[5, 10, 6]);
        assert_eq!(params, Some((5, 6)));
        // Broadcast dims in the middle cannot be expressed as cycle+repeat.
        let params = fast_broadcast_cycles_repeats(&[5, 1, 5], &[5, 6, 5]);
        assert_eq!(params, None);
        let params = fast_broadcast_cycles_repeats(&[1, 5, 1, 5, 1], &[2, 5, 6, 5, 2]);
        assert_eq!(params, None);
        // Shorter source shape is implicitly left-padded with 1s.
        let params = fast_broadcast_cycles_repeats(&[10], &[5, 3, 10]);
        assert_eq!(params, Some((15, 1)));
    }

    #[test]
    fn test_fast_broadcast_cycles() {
        let params = fast_broadcast_cycles(&[], &[1, 2, 3]);
        assert_eq!(params, Some(6));
        let params = fast_broadcast_cycles(&[1, 1, 1], &[5, 6, 2]);
        assert_eq!(params, Some(60));
        let params = fast_broadcast_cycles(&[3, 4, 5], &[3, 4, 5]);
        assert_eq!(params, Some(1));
        let params = fast_broadcast_cycles(&[1, 1, 10], &[5, 2, 10]);
        assert_eq!(params, Some(10));
        // Cases that need repeats > 1 are rejected by this variant.
        let params = fast_broadcast_cycles(&[10, 1, 1], &[10, 5, 6]);
        assert_eq!(params, None);
        let params = fast_broadcast_cycles(&[1, 10, 1], &[5, 10, 6]);
        assert_eq!(params, None);
        let params = fast_broadcast_cycles(&[5, 1, 5], &[5, 6, 5]);
        assert_eq!(params, None);
        let params = fast_broadcast_cycles(&[1, 5, 1, 5, 1], &[2, 5, 6, 5, 2]);
        assert_eq!(params, None);
        let params = fast_broadcast_cycles(&[10], &[5, 3, 10]);
        assert_eq!(params, Some(15));
    }

    // `to_shape` shorter than `from_shape` violates the function's
    // precondition and should panic.
    #[test]
    #[should_panic]
    fn test_fast_broadcast_cycles_repeats_invalid() {
        fast_broadcast_cycles_repeats(&[1, 2, 3], &[1, 2]);
    }

    #[test]
    fn test_add() -> Result<(), Box<dyn Error>> {
        let pool = new_pool();
        // Float tensor
        let a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let b = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let expected = Tensor::from_data(&[2, 2], vec![11., 22., 33., 44.]);
        let result = add(&pool, a.view(), b.view()).unwrap();
        expect_equal(&result, &expected)?;
        // Int tensor
        let a = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
        let b = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
        let expected = Tensor::from_data(&[2, 2], vec![11, 22, 33, 44]);
        let result = add(&pool, a.view(), b.view()).unwrap();
        assert_eq!(result, expected);
        Ok(())
    }

    #[test]
    fn test_add_broadcasted() -> Result<(), Box<dyn Error>> {
        let pool = new_pool();
        // Single-element operand, in either position.
        let a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let b = Tensor::from_data(&[1], vec![10.]);
        let expected = Tensor::from_data(&[2, 2], vec![11., 12., 13., 14.]);
        let result = add(&pool, a.view(), b.view()).unwrap();
        expect_equal(&result, &expected)?;
        let result = add(&pool, b.view(), a.view()).unwrap();
        expect_equal(&result, &expected)?;
        // Operands of different rank.
        let a = Tensor::from_data(&[5], vec![1., 2., 3., 4., 5.]);
        let b = Tensor::from_data(&[1, 5], vec![1., 2., 3., 4., 5.]);
        let expected = Tensor::from_data(&[1, 5], vec![2., 4., 6., 8., 10.]);
        let result = add(&pool, a.view(), b.view()).unwrap();
        expect_equal(&result, &expected)?;
        // Scalar operand.
        let a = Tensor::from(3.0);
        let b = Tensor::from_data(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let result = add(&pool, a.view(), b.view()).unwrap();
        let expected = Tensor::from_data(&[2, 2], vec![4.0, 5.0, 6.0, 7.0]);
        expect_equal(&result, &expected)?;
        // Both operands broadcast (outer-product style).
        let a = Tensor::from_data(&[2, 1], vec![1, 2]);
        let b = Tensor::from_data(&[1, 2], vec![3, 4]);
        let result = add(&pool, a.view(), b.view()).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), &[4, 5, 5, 6]);
        Ok(())
    }

    // Exercises the operand swap in `binary_commutative_op`.
    #[test]
    fn test_add_broadcast_first_input() {
        let pool = new_pool();
        let a: Tensor<i32> = Tensor::zeros(&[1, 1, 10]);
        let b = Tensor::zeros(&[1, 5, 10]);
        let result = add(&pool, a.view(), b.view()).unwrap();
        assert_eq!(result.shape(), &[1, 5, 10]);
    }

    #[test]
    fn test_add_in_place() -> Result<(), Box<dyn Error>> {
        let pool = new_pool();
        // Float tensors, same shape.
        let mut a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let a_copy = a.clone();
        let b = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let expected = Tensor::from_data(&[2, 2], vec![11., 22., 33., 44.]);
        add_in_place(a.view_mut(), b.view());
        expect_equal(&a, &expected)?;
        // Int tensors, same shape.
        let mut a_ints = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
        let b_ints = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
        let expected_ints = Tensor::from_data(&[2, 2], vec![11, 22, 33, 44]);
        add_in_place(a_ints.view_mut(), b_ints.view());
        assert_eq!(&a_ints, &expected_ints);
        // Via the Operator's run_in_place entry point.
        let op = Add {};
        let result = op
            .run_in_place(&pool, Output::FloatTensor(a_copy), (&b).into())
            .unwrap();
        expect_equal(result.as_float_ref().unwrap(), &expected)?;
        // Scalar first input: cannot run in place, falls back to `add`.
        let scalar = Tensor::from(1.0);
        let expected = Tensor::from_data(&[2, 2], vec![11., 21., 31., 41.]);
        let result = op
            .run_in_place(&pool, Output::FloatTensor(scalar), (&b).into())
            .unwrap();
        expect_equal(result.as_float_ref().unwrap(), &expected)?;
        // Broadcast second input.
        let mut a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let b = Tensor::from([1., 2.]);
        let expected = Tensor::from_data(&[2, 2], vec![2., 4., 4., 6.]);
        add_in_place(a.view_mut(), b.view());
        expect_equal(&a, &expected)?;
        // Non-contiguous first input: exercises the slow in-place path.
        let mut a = Tensor::from_data(&[2, 3], vec![1., 2., 0., 3., 4., 0.]);
        a.clip_dim(1, 0..2);
        assert!(!a.is_contiguous());
        let b = Tensor::from([1., 2.]);
        let expected = Tensor::from_data(&[2, 2], vec![2., 4., 4., 6.]);
        add_in_place(a.view_mut(), b.view());
        expect_equal(&a, &expected)?;
        Ok(())
    }

    #[test]
    fn test_add_invalid_broadcast() {
        let pool = new_pool();
        let a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let b = Tensor::from_data(&[2, 3], vec![1., 2., 3., 4., 5., 6.]);
        let op = Add {};
        let result = op.run(&pool, (&a, &b).into());
        assert_eq!(
            result.err(),
            Some(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))
        );
    }

    #[test]
    fn test_and() {
        let pool = new_pool();
        let a = Tensor::from([0, 1, 0, 1]);
        let b = Tensor::from([0, 0, 1, 1]);
        let expected = Tensor::from([0, 0, 0, 1]);
        let result = and(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_div() -> Result<(), Box<dyn Error>> {
        let pool = new_pool();
        // Float tensor / tensor.
        let a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let b = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let expected = Tensor::from_data(&[2, 2], vec![10., 10., 10., 10.]);
        let result = div(&pool, a.view(), b.view()).unwrap();
        expect_equal(&result, &expected)?;
        // Float tensor / scalar: exercises the reciprocal-multiply path.
        let a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let b = Tensor::from(10.);
        let expected = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let result = div(&pool, a.view(), b.view()).unwrap();
        expect_equal(&result, &expected)?;
        // Int tensor / tensor: truncating integer division.
        let a = Tensor::from([1, 2, 3, 4]);
        let b = Tensor::from([2, 2, 2, 2]);
        let expected = Tensor::from([0, 1, 1, 2]);
        let result = div(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        // Int tensor / scalar: must NOT use the reciprocal path.
        let a = Tensor::from([1, 2, 3, 4]);
        let b = Tensor::from(2);
        let expected = Tensor::from([0, 1, 1, 2]);
        let result = div(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_div_in_place() -> Result<(), Box<dyn Error>> {
        // Float tensor / tensor.
        let mut a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let b = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let expected = Tensor::from_data(&[2, 2], vec![10., 10., 10., 10.]);
        div_in_place(a.view_mut(), b.view());
        expect_equal(&a, &expected)?;
        // Float tensor / scalar.
        let mut a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let b = Tensor::from(10.);
        let expected = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        div_in_place(a.view_mut(), b.view());
        expect_equal(&a, &expected)?;
        // Int tensor / tensor.
        let mut a = Tensor::from([1, 2, 3, 4]);
        let b = Tensor::from([2, 2, 2, 2]);
        let expected = Tensor::from([0, 1, 1, 2]);
        div_in_place(a.view_mut(), b.view());
        assert_eq!(&a, &expected);
        // Int tensor / scalar.
        let mut a = Tensor::from([1, 2, 3, 4]);
        let b = Tensor::from(2);
        let expected = Tensor::from([0, 1, 1, 2]);
        div_in_place(a.view_mut(), b.view());
        assert_eq!(&a, &expected);
        Ok(())
    }

    #[test]
    fn test_equal() {
        let pool = new_pool();
        // Int tensor.
        let a = Tensor::from([1, 2]);
        let b = Tensor::from([1, 3]);
        let expected = Tensor::from([1, 0]);
        let result = equal(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        // Float tensor.
        let a = Tensor::from([1., 2.]);
        let b = Tensor::from([1., 3.]);
        let expected = Tensor::from([1, 0]);
        let result = equal(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_greater() {
        let pool = new_pool();
        let a = Tensor::from([1, 2, 5]);
        let b = Tensor::from([1, 3, 4]);
        let expected = Tensor::from([0, 0, 1]);
        let result = greater(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        let a = Tensor::from([1., 2., 5.]);
        let b = Tensor::from([1., 3., 4.]);
        let expected = Tensor::from([0, 0, 1]);
        let result = greater(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_greater_or_equal() {
        let pool = new_pool();
        let a = Tensor::from([1, 2, 5]);
        let b = Tensor::from([1, 3, 4]);
        let expected = Tensor::from([1, 0, 1]);
        let result = greater_or_equal(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        let a = Tensor::from([1., 2., 5.]);
        let b = Tensor::from([1., 3., 4.]);
        let expected = Tensor::from([1, 0, 1]);
        let result = greater_or_equal(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_less() {
        let pool = new_pool();
        let a = Tensor::from([1, 2]);
        let b = Tensor::from([1, 3]);
        let expected = Tensor::from([0, 1]);
        let result = less(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        let a = Tensor::from([1., 2.]);
        let b = Tensor::from([1., 3.]);
        let expected = Tensor::from([0, 1]);
        let result = less(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_less_or_equal() {
        let pool = new_pool();
        let a = Tensor::from([1, 2, 5]);
        let b = Tensor::from([1, 3, 4]);
        let expected = Tensor::from([1, 1, 0]);
        let result = less_or_equal(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        let a = Tensor::from([1., 2., 5.]);
        let b = Tensor::from([1., 3., 4.]);
        let expected = Tensor::from([1, 1, 0]);
        let result = less_or_equal(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_mod_op() {
        let pool = new_pool();
        // Int inputs: floor mod takes the divisor's sign, trunc mod the
        // dividend's.
        let a = Tensor::from([10, -10, 10, -10]);
        let b = Tensor::from([3, 3, -3, -3]);
        let expected = Tensor::from([1, 2, -2, -1]);
        let result = mod_op(&pool, a.view(), b.view(), DivMode::FloorDiv).unwrap();
        assert_eq!(&result, &expected);
        let expected = Tensor::from([1, -1, 1, -1]);
        let result = mod_op(&pool, a.view(), b.view(), DivMode::TruncDiv).unwrap();
        assert_eq!(&result, &expected);
        // Float inputs: same sign conventions.
        let af = Tensor::from([3.5, -3.5, 3.5, -3.5]);
        let bf = Tensor::from([2.5, 2.5, -2.5, -2.5]);
        let expected = Tensor::from([1., 1.5, -1.5, -1.]);
        let result = mod_op(&pool, af.view(), bf.view(), DivMode::FloorDiv).unwrap();
        assert_eq!(&result, &expected);
        let expected = Tensor::from([1., -1., 1., -1.]);
        let result = mod_op(&pool, af.view(), bf.view(), DivMode::TruncDiv).unwrap();
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_mul() -> Result<(), Box<dyn Error>> {
        let pool = new_pool();
        let a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let b = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let expected = Tensor::from_data(&[2, 2], vec![10., 40., 90., 160.]);
        let result = mul(&pool, a.view(), b.view()).unwrap();
        expect_equal(&result, &expected)?;
        let a = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
        let b = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
        let expected = Tensor::from_data(&[2, 2], vec![10, 40, 90, 160]);
        let result = mul(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_mul_in_place() -> Result<(), Box<dyn Error>> {
        let mut a = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let b = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let expected = Tensor::from_data(&[2, 2], vec![10., 40., 90., 160.]);
        mul_in_place(a.view_mut(), b.view());
        expect_equal(&a, &expected)?;
        let mut a = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
        let b = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
        let expected = Tensor::from_data(&[2, 2], vec![10, 40, 90, 160]);
        mul_in_place(a.view_mut(), b.view());
        assert_eq!(&a, &expected);
        Ok(())
    }

    #[test]
    fn test_or() {
        let pool = new_pool();
        let a = Tensor::from([0, 1, 0, 1]);
        let b = Tensor::from([0, 0, 1, 1]);
        let expected = Tensor::from([0, 1, 1, 1]);
        let result = or(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_pow() -> Result<(), Box<dyn Error>> {
        struct Case {
            a: Tensor<f32>,
            b: Tensor<f32>,
            expected: Tensor<f32>,
        }

        // Covers the square/cube fast paths, a general scalar exponent and
        // an elementwise (non-scalar) exponent.
        let cases = [
            Case {
                a: [2., 3., 4.].into(),
                b: (2.).into(),
                expected: [4., 9., 16.].into(),
            },
            Case {
                a: [2., 3., 4.].into(),
                b: (3.).into(),
                expected: [8., 27., 64.].into(),
            },
            Case {
                a: [2., 3., 4.].into(),
                b: (0.256).into(),
                expected: [(2f32).powf(0.256), (3f32).powf(0.256), (4f32).powf(0.256)].into(),
            },
            Case {
                a: [2., 3., 4.].into(),
                b: [1., 2., 3.].into(),
                expected: [2., 9., 64.].into(),
            },
        ];

        let pool = new_pool();
        for case in cases {
            // Copying variant.
            let result = pow(&pool, case.a.view(), case.b.view()).unwrap();
            expect_equal(&result, &case.expected)?;
            // In-place variant.
            let mut a = case.a.clone();
            pow_in_place(a.view_mut(), case.b.view());
            expect_equal(&a, &case.expected)?;
        }
        Ok(())
    }

    #[test]
    fn test_sub() -> Result<(), Box<dyn Error>> {
        let pool = new_pool();
        let a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let b = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let expected = Tensor::from_data(&[2, 2], vec![9., 18., 27., 36.]);
        let result = sub(&pool, a.view(), b.view()).unwrap();
        expect_equal(&result, &expected)?;
        let a = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
        let b = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
        let expected = Tensor::from_data(&[2, 2], vec![9, 18, 27, 36]);
        let result = sub(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
        Ok(())
    }

    #[test]
    fn test_sub_broadcast() -> Result<(), Box<dyn Error>> {
        let pool = new_pool();
        let a = Tensor::from([
            [[2, 4, 6], [8, 10, 12], [14, 16, 18]],
            [[3, 6, 9], [12, 15, 18], [21, 24, 27]],
        ]);
        let b = Tensor::from([[[1], [2], [3]]]);
        let expected = Tensor::from([
            [[1, 3, 5], [6, 8, 10], [11, 13, 15]],
            [[2, 5, 8], [10, 13, 16], [18, 21, 24]],
        ]);
        let result = sub(&pool, a.view(), b.view()).unwrap();
        expect_equal(&result, &expected)?;
        Ok(())
    }

    #[test]
    fn test_sub_in_place() -> Result<(), Box<dyn Error>> {
        let mut a = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let b = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let expected = Tensor::from_data(&[2, 2], vec![9., 18., 27., 36.]);
        sub_in_place(a.view_mut(), b.view());
        expect_equal(&a, &expected)?;
        let mut a = Tensor::from_data(&[2, 2], vec![10, 20, 30, 40]);
        let b = Tensor::from_data(&[2, 2], vec![1, 2, 3, 4]);
        let expected = Tensor::from_data(&[2, 2], vec![9, 18, 27, 36]);
        sub_in_place(a.view_mut(), b.view());
        assert_eq!(&a, &expected);
        Ok(())
    }

    #[test]
    fn test_where() {
        let pool = new_pool();
        // All inputs the same shape.
        let cond = Tensor::from_data(&[2, 2], vec![1, 0, 0, 1]);
        let x = Tensor::from_data(&[2, 2], vec![1., 2., 3., 4.]);
        let y = Tensor::from_data(&[2, 2], vec![10., 20., 30., 40.]);
        let result = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
        let expected = Tensor::from_data(&[2, 2], vec![1., 20., 30., 4.]);
        assert_eq!(&result, &expected);
        // Scalar float values.
        let cond = Tensor::from([1, 1, 0, 0]);
        let x = Tensor::from(1.);
        let y = Tensor::from(2.);
        let result = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
        let expected = Tensor::from([1., 1., 2., 2.]);
        assert_eq!(&result, &expected);
        // Scalar condition.
        let cond = Tensor::from(1);
        let x = Tensor::from([1., 2.]);
        let y = Tensor::from([3., 4.]);
        let result = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
        let expected = Tensor::from([1., 2.]);
        assert_eq!(&result, &expected);
        // Scalar int values.
        let cond = Tensor::from([1, 1, 0, 0]);
        let x = Tensor::from(3);
        let y = Tensor::from(4);
        let result = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
        let expected = Tensor::from([3, 3, 4, 4]);
        assert_eq!(&result, &expected);
        // Values broadcast against the condition: forces the general path.
        let cond = Tensor::from([[1, 0], [1, 0]]);
        let x = Tensor::from([[1], [2]]);
        let y = Tensor::from([[3], [4]]);
        let result = where_op(&pool, cond.view(), x.view(), y.view()).unwrap();
        let expected = Tensor::from([[1, 3], [2, 4]]);
        assert_eq!(&result, &expected);
    }

    #[test]
    fn test_where_invalid_inputs() {
        let pool = new_pool();
        let cond = Tensor::from([1, 1]);
        let x = Tensor::from([1, 2, 3]);
        let y = Tensor::from([2, 2]);
        // Failure to broadcast `x` against `y`.
        let result = where_op(&pool, cond.view(), x.view(), y.view());
        assert_eq!(
            result.err(),
            Some(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))
        );
        // Failure to broadcast against `cond`.
        let result = where_op(&pool, cond.view(), y.view(), x.view());
        assert_eq!(
            result.err(),
            Some(OpError::IncompatibleInputShapes("Cannot broadcast inputs"))
        );
    }

    #[test]
    fn test_xor() {
        let pool = new_pool();
        let a = Tensor::from([0, 1, 0, 1]);
        let b = Tensor::from([0, 0, 1, 1]);
        let expected = Tensor::from([0, 1, 1, 0]);
        let result = xor(&pool, a.view(), b.view()).unwrap();
        assert_eq!(&result, &expected);
    }
}