use rten_shape_inference::ops as shape_ops;
use rten_tensor::layout::is_valid_permutation;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, Tensor, TensorView};
use smallvec::SmallVec;

use crate::buffer_pool::{AutoReturn, BufferPool};
use crate::infer_shapes::{InferShapes, impl_infer_shapes};
use crate::operator::{
    IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
    OutputTypesContext, static_dims,
};
use crate::ops::binary_elementwise::{broadcast_shapes, fast_broadcast_cycles_repeats};
use crate::ops::{map_value, map_value_view, resolve_axes, resolve_axis};
use crate::value::{DataType, Value, ValueType, ValueView};
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum DepthToSpaceMode {
DepthColumnRow,
ColumnRowDepth,
}
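/// Rearrange blocks of channel data into spatial blocks, as in the ONNX
/// `DepthToSpace` operator. An NCHW input of shape `[n, c, h, w]` with `c`
/// divisible by `block_size^2` produces an output of shape
/// `[n, c / block_size^2, h * block_size, w * block_size]`.
///
/// A minimal usage sketch (not a doctest; it mirrors `test_depth_to_space`
/// at the bottom of this file):
///
/// ```ignore
/// let pool = BufferPool::new();
/// // [1, 8, 1, 1] input whose channel values are 1..=8.
/// let input = Tensor::arange(1.0, 9.0, None).into_shape([1, 8, 1, 1].as_slice());
/// let out = depth_to_space(&pool, input.view(), 2, DepthToSpaceMode::DepthColumnRow)?;
/// // DCR mode: output channel 0 is [[1, 3], [5, 7]], channel 1 is [[2, 4], [6, 8]].
/// assert_eq!(out.shape(), &[1, 2, 2, 2]);
/// ```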
pub fn depth_to_space<T: Clone>(
pool: &BufferPool,
input: TensorView<T>,
block_size: u32,
mode: DepthToSpaceMode,
) -> Result<Tensor<T>, OpError> {
if block_size == 0 {
return Err(OpError::InvalidValue("`block_size` must be > 0"));
}
let input = static_dims!(input, 4, "NCHW")?;
let [n, c, h, w] = input.shape();
let block_size = block_size as usize;
if c % (block_size * block_size) != 0 {
return Err(OpError::InvalidValue(
"input channels must be a multiple of `block_size` squared",
));
}
let new_c = c / (block_size * block_size);
let new_shape = [n, new_c, h * block_size, w * block_size];
let tmp = input.to_contiguous_in(pool);
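    // Decompose the channel dimension into the two block axes plus the
    // reduced channel count, ordered according to `mode`, then interleave
    // the block axes with the spatial axes before collapsing back to NCHW.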
let tmp = match mode {
DepthToSpaceMode::DepthColumnRow => tmp.reshaped([n, block_size, block_size, new_c, h, w]),
DepthToSpaceMode::ColumnRowDepth => tmp.reshaped([n, new_c, block_size, block_size, h, w]),
};
let tmp = match mode {
DepthToSpaceMode::DepthColumnRow => tmp.permuted([0, 3, 4, 1, 5, 2]),
DepthToSpaceMode::ColumnRowDepth => tmp.permuted([0, 1, 4, 2, 5, 3]),
};
let mut tmp = tmp.to_tensor_in(pool).into_dyn();
tmp.reshape(&new_shape);
Ok(tmp)
}
#[derive(Debug)]
pub struct DepthToSpace {
pub block_size: u32,
pub mode: DepthToSpaceMode,
}
impl Operator for DepthToSpace {
fn name(&self) -> &str {
"DepthToSpace"
}
fn max_inputs(&self) -> Option<usize> {
Some(1)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let input = ctx.inputs().require_as(0)?;
depth_to_space::<f32>(ctx.pool(), input, self.block_size, self.mode).into_op_result()
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::CopyFromInput(0)].into())
}
}
fn expand_output_shape(
input_shape: &[usize],
shape: &NdTensorView<i32, 1>,
) -> Result<SmallVec<[usize; 4]>, OpError> {
let shape_vec: SmallVec<[usize; 4]> = shape.iter().map(|el| *el as usize).collect();
broadcast_shapes(input_shape, &shape_vec).ok_or(OpError::IncompatibleInputShapes(
"Cannot broadcast input with target shape",
))
}
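/// Broadcast `input` to `out_shape`, which must be a shape that
/// `input.shape()` broadcasts to.
///
/// For contiguous inputs the broadcast reduces to copying the input
/// `cycles` times with each element repeated `repeats` times. Worked
/// example (mirroring `test_expand` below): expanding shape `[1, 2, 1]`
/// with data `[1, 2]` to `[2, 2, 2]` uses `cycles = 2` and `repeats = 2`,
/// yielding `[1, 1, 2, 2, 1, 1, 2, 2]`.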
pub(crate) fn expand_to<T: Copy>(
pool: &BufferPool,
input: TensorView<T>,
out_shape: &[usize],
) -> Tensor<T> {
let out_len = out_shape.iter().product();
match (
input.data(),
fast_broadcast_cycles_repeats(input.shape(), out_shape),
) {
        (Some(in_data), Some((cycles, repeats))) => {
            assert!(out_len == input.len() * cycles * repeats);

            let mut out_data: Vec<T> = pool.alloc(out_len);
            let mut out_ptr = out_data.as_mut_ptr();
            for _ in 0..cycles {
                if repeats == 1 {
                    // Copy the whole input once per broadcast cycle.
                    //
                    // Safety: the assert above guarantees `out_data` has
                    // capacity for `cycles` copies of `in_data`.
                    unsafe {
                        std::ptr::copy_nonoverlapping(in_data.as_ptr(), out_ptr, in_data.len());
                        out_ptr = out_ptr.add(in_data.len());
                    }
                } else {
                    for el in in_data.iter() {
                        for _ in 0..repeats {
                            // Safety: the total number of writes is
                            // `cycles * in_data.len() * repeats == out_len`,
                            // which is within the allocated capacity.
                            unsafe {
                                *out_ptr = *el;
                                out_ptr = out_ptr.add(1);
                            }
                        }
                    }
                }
            }
            // Safety: all `out_len` elements were initialized above.
            unsafe { out_data.set_len(out_len) };
Tensor::from_data(out_shape, out_data)
}
_ => input.broadcast(out_shape).to_tensor_in(pool),
}
}
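/// Broadcast `input` to the shape obtained by broadcasting `input.shape()`
/// with `shape`, per the ONNX `Expand` operator.
///
/// A minimal usage sketch (not a doctest; see `test_expand` below):
///
/// ```ignore
/// let pool = BufferPool::new();
/// let input = Tensor::from(5);
/// let shape = NdTensor::from([2i32, 2]);
/// let out = expand(&pool, input.view(), &shape.view())?;
/// assert_eq!(out.shape(), &[2, 2]); // Four copies of the scalar 5.
/// ```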
pub fn expand<T: Copy>(
pool: &BufferPool,
input: TensorView<T>,
shape: &NdTensorView<i32, 1>,
) -> Result<Tensor<T>, OpError> {
let out_shape = expand_output_shape(input.shape(), shape)?;
Ok(expand_to(pool, input, &out_shape))
}
#[derive(Debug)]
pub struct Expand {}
impl Operator for Expand {
fn name(&self) -> &str {
"Expand"
}
fn max_inputs(&self) -> Option<usize> {
Some(2)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let inputs = ctx.inputs();
let input = inputs.require(0)?;
let shape = inputs.require_as(1)?;
map_value_view!(input, x, { expand(ctx.pool(), x, &shape).into_op_result() })
}
fn can_run_in_place(&self) -> bool {
true
}
fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
let shape = ctx.inputs().require_as(0)?;
let out_shape = expand_output_shape(&input.shape(), &shape)?;
if input.shape() == out_shape {
return Ok(input);
}
map_value!(input, input, {
let input = input.auto_return(ctx.pool());
let output = expand_to(ctx.pool(), input.view(), &out_shape);
Ok(output.into())
})
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::CopyFromInput(0)].into())
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
Some(&shape_ops::Expand)
}
}
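/// Compute the `[outer, inner]` shape produced by flattening `shape` at
/// `axis`: dimensions before `axis` are multiplied into `outer` and the
/// remainder into `inner`. `axis` may equal the rank, in which case `inner`
/// is 1. For example (see `test_flatten` below), shape `[2, 3, 1, 4]` with
/// `axis = 2` flattens to `[6, 4]`, while `axis = 0` gives `[1, 24]`.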
fn flattened_shape(shape: &[usize], axis: isize) -> Result<[usize; 2], OpError> {
let outer_dims = if axis == shape.len() as isize {
shape.len()
} else {
resolve_axis(shape.len(), axis)?
};
let outer_size = shape.iter().take(outer_dims).product();
let inner_size = shape.iter().skip(outer_dims).product();
Ok([outer_size, inner_size])
}
pub fn flatten<T: Copy>(
pool: &BufferPool,
input: TensorView<T>,
axis: isize,
) -> Result<Tensor<T>, OpError> {
let shape = flattened_shape(input.shape(), axis)?;
let mut output = input.to_tensor_in(pool);
output.reshape(&shape);
Ok(output)
}
pub fn flatten_in_place<T: Copy>(
pool: &BufferPool,
input: &mut Tensor<T>,
axis: isize,
) -> Result<(), OpError> {
let shape = flattened_shape(input.shape(), axis)?;
input.reshape_in(pool, &shape);
Ok(())
}
#[derive(Debug)]
pub struct Flatten {
pub axis: isize,
}
impl Operator for Flatten {
fn name(&self) -> &str {
"Flatten"
}
fn max_inputs(&self) -> Option<usize> {
Some(1)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let input = ctx.inputs().require(0)?;
map_value_view!(input, x, {
flatten(ctx.pool(), x, self.axis).into_op_result()
})
}
fn can_run_in_place(&self) -> bool {
true
}
fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
map_value!(input, x, {
flatten_in_place(ctx.pool(), &mut x, self.axis)?;
Ok(x.into())
})
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::CopyFromInput(0)].into())
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
Some(self)
}
}
impl_infer_shapes!(
Flatten,
op,
shape_ops::Flatten {
axis: op.axis as i32
}
);
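/// Resolve an ONNX `Reshape` target shape against `input_shape`. A `-1`
/// entry is inferred from the remaining input length; a `0` entry copies
/// the corresponding input dimension, unless `allow_zero` is set, in which
/// case it denotes a literal zero-sized dimension. Worked example (see
/// `test_reshape` below): input `[2, 8]` with shape `[4, -1]` resolves to
/// `[4, 4]`, and with shape `[0, 0]` to `[2, 8]`.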
fn resolve_shape(
input_shape: &[usize],
shape: &NdTensorView<i32, 1>,
allow_zero: bool,
) -> Result<SmallVec<[usize; 4]>, OpError> {
let mut unspecified_dim = None;
let mut specified_dims_size = 1;
for (dim, &size) in shape.iter().enumerate() {
if size < -1 {
return Err(OpError::InvalidValue("Invalid dimension size in shape"));
} else if size == 0 && !allow_zero {
if dim >= input_shape.len() {
return Err(OpError::InvalidValue(
"Zero dim has no corresponding input dim",
));
}
specified_dims_size *= input_shape[dim];
} else if size != -1 {
specified_dims_size *= size as usize;
} else if unspecified_dim.is_some() {
return Err(OpError::InvalidValue(
"Multiple dimensions in new shape set to -1",
));
} else {
unspecified_dim = Some(dim);
}
}
let input_len = input_shape.iter().product();
let (unspecified_dim_size, remainder) = match input_len {
0 => (0, 0),
_ => {
if specified_dims_size == 0 {
(0, 1)
} else {
(
input_len / specified_dims_size,
input_len % specified_dims_size,
)
}
}
};
if remainder != 0 || (unspecified_dim.is_none() && unspecified_dim_size > 1) {
return Err(OpError::InvalidValue(
"Input length does not match target shape",
));
}
Ok(shape
.iter()
.enumerate()
.map(|(dim, &size)| match size {
-1 => unspecified_dim_size,
0 if !allow_zero => input_shape[dim],
valid => valid as usize,
})
.collect())
}
pub fn reshape<T: Copy>(
pool: &BufferPool,
input: TensorView<T>,
shape: &NdTensorView<i32, 1>,
allow_zero: bool,
) -> Result<Tensor<T>, OpError> {
let out_shape = resolve_shape(input.shape(), shape, allow_zero)?;
let output = input.to_tensor_in(pool).into_shape(out_shape.as_slice());
Ok(output)
}
pub fn reshape_in_place<T: Copy>(
pool: &BufferPool,
input: &mut Tensor<T>,
shape: &NdTensorView<i32, 1>,
allow_zero: bool,
) -> Result<(), OpError> {
let out_shape = resolve_shape(input.shape(), shape, allow_zero)?;
input.reshape_in(pool, &out_shape);
Ok(())
}
#[derive(Debug)]
pub struct Reshape {
pub allow_zero: bool,
}
impl Operator for Reshape {
fn name(&self) -> &str {
"Reshape"
}
fn max_inputs(&self) -> Option<usize> {
Some(2)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let inputs = ctx.inputs();
let input = inputs.require(0)?;
let shape = inputs.require_as(1)?;
map_value_view!(input, x, {
reshape(ctx.pool(), x, &shape, self.allow_zero).into_op_result()
})
}
fn can_run_in_place(&self) -> bool {
true
}
fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
let shape = ctx.inputs().require_as(0)?;
map_value!(input, output, {
reshape_in_place(ctx.pool(), &mut output, &shape, self.allow_zero)?;
Ok(output.into())
})
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::CopyFromInput(0)].into())
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
Some(self)
}
}
impl_infer_shapes!(
Reshape,
op,
shape_ops::Reshape {
allow_zero: op.allow_zero
}
);
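/// Operator which returns the shape of its input as an `i32` tensor,
/// optionally restricted to the dimension range `[start, end)`. Negative
/// indices count back from the last dimension and out-of-range values are
/// clamped. For example (see `test_shape` below), an input of shape
/// `[1, 2, 3, 4]` with `start: Some(1)` and `end: Some(3)` yields `[2, 3]`.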
#[derive(Debug, Default)]
pub struct Shape {
pub start: Option<i32>,
pub end: Option<i32>,
}
impl Operator for Shape {
fn name(&self) -> &str {
"Shape"
}
fn max_inputs(&self) -> Option<usize> {
Some(1)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let input = ctx.inputs().require(0)?;
let shape_op = shape_ops::Shape {
start: self.start,
end: self.end,
};
let dim_range = shape_op.resolve_start_end(input.ndim());
let shape_slice = &input.shape()[dim_range];
let mut data = ctx.pool().alloc(input.ndim());
data.extend(shape_slice.iter().map(|&el| el as i32));
Tensor::from_data(&[shape_slice.len()], data).into_op_result()
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
Some(self)
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::Fixed(ValueType::Tensor(DataType::Int32))].into())
}
}
impl_infer_shapes!(
Shape,
op,
shape_ops::Shape {
start: op.start,
end: op.end
}
);
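/// Operator which returns the total number of elements in its input as an
/// `i32` scalar tensor.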
#[derive(Debug)]
pub struct Size {}
impl Operator for Size {
fn name(&self) -> &str {
"Size"
}
fn max_inputs(&self) -> Option<usize> {
Some(1)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let input = ctx.inputs().require(0)?;
let len = input.len() as i32;
let mut output = Tensor::zeros_in(ctx.pool(), &[]);
output[[]] = len;
output.into_op_result()
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::Fixed(ValueType::Tensor(DataType::Int32))].into())
}
}
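/// Remove size-1 dimensions from `input`. If `axes` is provided, only the
/// listed dimensions are removed and each must have size 1; otherwise all
/// size-1 dimensions are removed. Axes are processed in ascending order,
/// adjusting indices for previously removed dimensions. For example (see
/// `test_squeeze` below), squeezing shape `[1, 5, 5, 1]` with no axes
/// yields `[5, 5]`, while axes `[0]` yields `[5, 5, 1]`.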
pub fn squeeze_in_place<T: Clone>(
input: &mut Tensor<T>,
axes: Option<NdTensorView<i32, 1>>,
) -> Result<(), OpError> {
let axes = axes
.map(|axes| resolve_axes(input.ndim(), axes.iter()))
.transpose()?;
let sorted_axes = if let Some(mut axes) = axes {
for &axis in axes.iter() {
if axis >= input.ndim() {
return Err(OpError::InvalidValue("Axis is invalid"));
}
if input.size(axis) != 1 {
return Err(OpError::InvalidValue(
"Can only remove dimensions of size 1",
));
}
}
axes.sort();
axes
} else {
input
.shape()
.iter()
.enumerate()
.filter_map(|(i, size)| if *size == 1 { Some(i) } else { None })
.collect()
};
for (n_removed, axis) in sorted_axes.into_iter().enumerate() {
input.remove_axis(axis - n_removed);
}
Ok(())
}
pub fn squeeze<T: Copy>(
pool: &BufferPool,
input: TensorView<T>,
axes: Option<NdTensorView<i32, 1>>,
) -> Result<Tensor<T>, OpError> {
let mut output = input.to_tensor_in(pool);
squeeze_in_place(&mut output, axes)?;
Ok(output)
}
#[derive(Debug)]
pub struct Squeeze {}
impl Operator for Squeeze {
fn name(&self) -> &str {
"Squeeze"
}
fn max_inputs(&self) -> Option<usize> {
Some(2)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let inputs = ctx.inputs();
let input = inputs.require(0)?;
let axes = inputs.get_as(1)?;
map_value_view!(input, x, { squeeze(ctx.pool(), x, axes).into_op_result() })
}
fn can_run_in_place(&self) -> bool {
true
}
fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
let axes = ctx.inputs().get_as(0)?;
map_value!(input, output, {
squeeze_in_place(&mut output, axes)?;
Ok(output.into())
})
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::CopyFromInput(0)].into())
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
Some(&shape_ops::Squeeze)
}
}
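/// Return a copy of `input` with its dimensions reordered. If `permutation`
/// is `None` the dimension order is reversed, matching the default of the
/// ONNX `Transpose` operator; otherwise it must contain each index in
/// `0..input.ndim()` exactly once.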
pub fn transpose<T: Copy>(
pool: &BufferPool,
input: TensorView<T>,
permutation: Option<&[usize]>,
) -> Result<Tensor<T>, OpError> {
let mut transposed = input.view();
match permutation {
Some(order) => {
if !is_valid_permutation(input.ndim(), order) {
return Err(OpError::InvalidValue("Permutation is invalid"));
}
transposed.permute(order)
}
None => {
transposed.transpose();
}
};
let output = Tensor::uninit_in(pool, transposed.shape());
Ok(output.init_from(&transposed))
}
#[derive(Debug)]
pub struct Transpose {
pub perm: Option<Vec<usize>>,
}
impl Operator for Transpose {
fn name(&self) -> &str {
"Transpose"
}
fn max_inputs(&self) -> Option<usize> {
Some(1)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let input = ctx.inputs().require(0)?;
let perm_slice = self.perm.as_deref();
map_value_view!(input, x, {
transpose(ctx.pool(), x, perm_slice).into_op_result()
})
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::CopyFromInput(0)].into())
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
Some(self)
}
}
impl_infer_shapes!(
Transpose,
op,
shape_ops::Transpose {
perm: op.perm.as_deref(),
}
);
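/// Insert size-1 dimensions into `input` at the given `axes`, which are
/// resolved against the rank of the *output* tensor and must be unique.
///
/// A minimal usage sketch (not a doctest; see `test_unsqueeze` below):
///
/// ```ignore
/// // Input of shape [3, 4, 5]; axes may be given in any order.
/// let out = unsqueeze_in_place(input, &NdTensor::from([0, 4]).view())?;
/// assert_eq!(out.shape(), &[1, 3, 4, 5, 1]);
/// ```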
pub fn unsqueeze_in_place<T: Clone>(
mut input: Tensor<T>,
axes: &NdTensorView<i32, 1>,
) -> Result<Tensor<T>, OpError> {
let sorted_axes = if axes.len() == 1 {
let axis = resolve_axis(input.ndim() + 1, axes[0] as isize)?;
SmallVec::from_slice(&[axis])
} else {
let mut sorted_axes = resolve_axes(input.ndim() + axes.len(), axes.iter())?;
sorted_axes.sort_unstable();
        // `sorted_axes` is in ascending order, so duplicates must be
        // adjacent; compare each element with its successor.
        let axes_unique = sorted_axes
            .iter()
            .skip(1)
            .zip(sorted_axes.iter())
            .all(|(next, prev)| next != prev);
        if !axes_unique {
            return Err(OpError::InvalidValue("Axes must be unique"));
        }
sorted_axes
};
for axis in sorted_axes {
input.insert_axis(axis);
}
Ok(input)
}
pub fn unsqueeze<T: Copy>(
pool: &BufferPool,
input: TensorView<T>,
axes: &NdTensorView<i32, 1>,
) -> Result<Tensor<T>, OpError> {
unsqueeze_in_place(input.to_tensor_in(pool), axes)
}
#[derive(Debug)]
pub struct Unsqueeze {}
impl Operator for Unsqueeze {
fn name(&self) -> &str {
"Unsqueeze"
}
fn max_inputs(&self) -> Option<usize> {
Some(2)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let inputs = ctx.inputs();
let input = inputs.require(0)?;
let axes = inputs.require_as(1)?;
map_value_view!(input, x, {
unsqueeze(ctx.pool(), x, &axes).into_op_result()
})
}
fn can_run_in_place(&self) -> bool {
true
}
fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
let axes = ctx.inputs().require_as(0)?;
map_value!(input, output, {
Ok(unsqueeze_in_place(output, &axes)?.into())
})
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::CopyFromInput(0)].into())
}
fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&shape_ops::Unsqueeze)
}
}
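/// One dimension of a `ComputeShape` output: either a fixed size or a
/// reference to dimension `dim` of the runtime input at index `input`.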
#[derive(Debug, Clone, PartialEq)]
pub enum DimSpec {
Static(u32),
Dynamic { input: u32, dim: u32 },
}
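/// Operator which assembles a shape tensor from static sizes and dimensions
/// of its runtime inputs. Worked example (mirroring `test_compute_shape`
/// below): with inputs of shape `[2, 4, 8]` and `[24]`, the spec
/// `[Static(3), Dynamic { input: 0, dim: 1 }, Static(5), Dynamic { input: 1, dim: 0 }]`
/// evaluates to `[3, 4, 5, 24]`.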
#[derive(Debug)]
pub struct ComputeShape {
pub shape: Vec<DimSpec>,
}
impl Operator for ComputeShape {
fn name(&self) -> &str {
"ComputeShape"
}
fn max_inputs(&self) -> Option<usize> {
None
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let inputs = ctx.inputs();
let output = self
.shape
.iter()
.map(|dim| match dim {
DimSpec::Static(size) => Ok(*size as i32),
DimSpec::Dynamic {
input: input_idx,
dim,
} => {
let dim = *dim as usize;
let input = inputs.require(*input_idx as usize)?;
if input.ndim() > dim {
Ok(input.size(dim).min(i32::MAX as usize) as i32)
} else {
Err(OpError::InvalidValue(
"Dim index invalid for input tensor shape",
))
}
}
})
.collect::<Result<Vec<i32>, _>>()?;
Tensor::from(output).into_op_result()
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::Fixed(ValueType::Tensor(DataType::Int32))].into())
}
}
#[cfg(test)]
mod tests {
use std::error::Error;
use rten_bench::run_bench;
use rten_tensor::prelude::*;
use rten_tensor::rng::XorShiftRng;
use rten_tensor::test_util::expect_equal;
    use rten_tensor::{NdTensor, NdTensorView, Tensor, TensorView, TensorViewMut};
use rten_testing::TestCases;
use super::{ComputeShape, DepthToSpaceMode, DimSpec, depth_to_space};
use crate::buffer_pool::BufferPool;
use crate::operator::{OpError, OperatorExt};
use crate::ops::layout::{
Reshape, Shape, Size, expand, flatten, reshape, reshape_in_place, squeeze,
squeeze_in_place, transpose, unsqueeze,
};
use crate::value::Value;
#[test]
fn test_compute_shape() {
let input_a = NdTensor::<f32, _>::zeros([2, 4, 8]);
let input_b = NdTensor::<f32, _>::zeros([24]);
let op = ComputeShape {
shape: [
DimSpec::Static(3),
DimSpec::Dynamic { input: 0, dim: 1 },
DimSpec::Static(5),
DimSpec::Dynamic { input: 1, dim: 0 },
]
.into(),
};
let result: NdTensor<i32, 1> = op.run_simple((input_a.view(), input_b.view())).unwrap();
assert_eq!(result, NdTensorView::from(&[3, 4, 5, 24]));
let op = ComputeShape {
shape: [DimSpec::Dynamic { input: 1, dim: 0 }].into(),
};
let result: Result<NdTensor<i32, 1>, _> = op.run_simple(input_a.view());
assert_eq!(result.err().unwrap(), OpError::MissingInputs);
let op = ComputeShape {
shape: [DimSpec::Dynamic { input: 0, dim: 3 }].into(),
};
let result: Result<NdTensor<i32, 1>, _> = op.run_simple(input_a.view());
assert_eq!(
result.err().unwrap(),
OpError::InvalidValue("Dim index invalid for input tensor shape")
);
}
#[test]
fn test_depth_to_space() {
#[derive(Debug)]
struct Case {
input: NdTensor<f32, 4>,
block_size: u32,
mode: DepthToSpaceMode,
expected: Result<Tensor, OpError>,
}
let input = NdTensor::from([
[[1.0]],
[[2.0]],
[[3.0]],
[[4.0]],
[[5.0]],
[[6.0]],
[[7.0]],
[[8.0]],
])
.into_shape([1, 8, 1, 1]);
let cases = [
Case {
input: input.clone(),
block_size: 2,
mode: DepthToSpaceMode::DepthColumnRow,
expected: Ok(
NdTensor::from([[[1.0, 3.0], [5.0, 7.0]], [[2.0, 4.0], [6.0, 8.0]]])
.into_shape([1, 2, 2, 2].as_slice()),
),
},
Case {
input: input.clone(),
block_size: 2,
mode: DepthToSpaceMode::ColumnRowDepth,
expected: Ok(
NdTensor::from([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
.into_shape([1, 2, 2, 2].as_slice()),
),
},
Case {
input: NdTensor::full([1, 16, 2, 2], 1.0),
block_size: 3,
mode: DepthToSpaceMode::ColumnRowDepth,
expected: Err(OpError::InvalidValue(
"input channels must be a multiple of `block_size` squared",
)),
},
Case {
input: NdTensor::full([1, 16, 2, 2], 1.0),
block_size: 0,
mode: DepthToSpaceMode::ColumnRowDepth,
expected: Err(OpError::InvalidValue("`block_size` must be > 0")),
},
];
cases.test_each(|case| {
let pool = BufferPool::new();
let result = depth_to_space(&pool, case.input.as_dyn(), case.block_size, case.mode);
assert_eq!(result, case.expected);
})
}
#[test]
fn test_expand() {
#[derive(Debug)]
struct Case {
input: Tensor<i32>,
shape: Vec<i32>,
expected: Result<(Vec<usize>, Option<Vec<i32>>), OpError>,
}
fn make_tensor(shape: impl AsRef<[usize]>) -> Tensor<i32> {
let len = shape.as_ref().iter().product::<usize>() as i32;
Tensor::arange(1, len + 1, None).into_shape(shape.as_ref())
}
let cases = [
Case {
input: Tensor::from(5),
shape: vec![2, 2],
expected: Ok((vec![2, 2], Some(vec![5, 5, 5, 5]))),
},
Case {
input: make_tensor([3, 1]),
shape: vec![2, 3, 1],
expected: Ok((vec![2, 3, 1], None)),
},
Case {
input: make_tensor([3, 1]),
shape: vec![2, 1, 6],
expected: Ok((vec![2, 3, 6], None)),
},
Case {
input: make_tensor([3, 1]),
shape: vec![3, 4],
expected: Ok((vec![3, 4], None)),
},
Case {
input: make_tensor([1, 2, 1]),
shape: vec![2, 2, 2],
expected: Ok((vec![2, 2, 2], Some(vec![1, 1, 2, 2, 1, 1, 2, 2]))),
},
Case {
input: make_tensor([2, 1, 2]),
shape: vec![2, 2, 2],
expected: Ok((vec![2, 2, 2], Some(vec![1, 2, 1, 2, 3, 4, 3, 4]))),
},
Case {
input: Tensor::from([1, 2, 3]),
shape: vec![2, 2],
expected: Err(OpError::IncompatibleInputShapes(
"Cannot broadcast input with target shape",
)),
},
];
cases.test_each(|case| {
let pool = BufferPool::new();
let shape = NdTensorView::from(case.shape.as_slice());
let result = expand(&pool, case.input.view(), &shape);
match (&result, &case.expected) {
(Ok(output), Ok((expected_shape, expected_data))) => {
assert_eq!(output.shape(), expected_shape.as_slice());
if let Some(data) = expected_data {
assert_eq!(output.data().unwrap(), data);
}
}
(output, expected) => assert_eq!(output.as_ref().err(), expected.as_ref().err()),
}
})
}
#[test]
fn test_flatten() {
#[derive(Debug)]
struct Case {
shape: Vec<usize>,
axis: isize,
expected: Result<Vec<usize>, OpError>,
}
let cases = [
Case {
shape: [1, 5, 1, 1].into(),
axis: 1,
expected: Ok([1, 5].into()),
},
Case {
shape: [2, 3, 1, 4].into(),
axis: 2,
expected: Ok([6, 4].into()),
},
Case {
shape: [2, 3, 1, 4].into(),
axis: 0,
expected: Ok([1, 24].into()),
},
Case {
shape: [2, 2].into(),
axis: 2,
expected: Ok([4, 1].into()),
},
Case {
shape: [2, 3, 4].into(),
axis: -1,
expected: Ok([6, 4].into()),
},
Case {
shape: [2, 3, 4].into(),
axis: -2,
expected: Ok([2, 12].into()),
},
Case {
shape: [2, 3, 4].into(),
axis: -3,
expected: Ok([1, 24].into()),
},
Case {
shape: [2, 3, 4].into(),
axis: 4,
expected: Err(OpError::InvalidValue("Axis is invalid")),
},
Case {
shape: [2, 3, 4].into(),
axis: -4,
expected: Err(OpError::InvalidValue("Axis is invalid")),
},
];
cases.test_each(|case| {
let pool = BufferPool::new();
let input = Tensor::<f32>::zeros(case.shape.as_slice());
let result =
flatten(&pool, input.view(), case.axis).map(|tensor| tensor.shape().to_vec());
assert_eq!(result, case.expected);
})
}
#[test]
fn test_reshape() {
#[derive(Debug)]
struct Case<'a> {
input: &'a [usize],
shape: &'a [i32],
allow_zero: bool,
expected: Result<&'a [usize], OpError>,
}
let cases = [
Case {
input: &[2, 8],
shape: &[16],
allow_zero: false,
expected: Ok(&[16]),
},
Case {
input: &[2, 8],
shape: &[0, 0],
allow_zero: false,
expected: Ok(&[2, 8]),
},
Case {
input: &[0, 0],
shape: &[0, 0],
allow_zero: false,
expected: Ok(&[0, 0]),
},
Case {
input: &[2, 8],
shape: &[4, -1],
allow_zero: false,
expected: Ok(&[4, 4]),
},
Case {
input: &[2, 0],
shape: &[4, -1],
allow_zero: false,
expected: Ok(&[4, 0]),
},
Case {
input: &[8],
shape: &[9],
allow_zero: false,
expected: Err(OpError::InvalidValue(
"Input length does not match target shape",
)),
},
Case {
input: &[2],
shape: &[2, 0],
allow_zero: false,
expected: Err(OpError::InvalidValue(
"Zero dim has no corresponding input dim",
)),
},
Case {
input: &[0, 0, 10],
shape: &[10, 0, 0],
allow_zero: true,
expected: Ok(&[10, 0, 0]),
},
Case {
input: &[10, 1],
shape: &[10, 0],
allow_zero: true,
expected: Err(OpError::InvalidValue(
"Input length does not match target shape",
)),
},
Case {
input: &[2, 2],
shape: &[-1, -1],
allow_zero: false,
expected: Err(OpError::InvalidValue(
"Multiple dimensions in new shape set to -1",
)),
},
Case {
input: &[2, 8],
shape: &[1, 8],
allow_zero: false,
expected: Err(OpError::InvalidValue(
"Input length does not match target shape",
)),
},
];
for case in cases {
let pool = BufferPool::new();
let input = Tensor::<f32>::zeros(case.input);
let shape = NdTensorView::from(case.shape);
let result = reshape(&pool, input.view(), &shape.view(), case.allow_zero);
let shape = result.as_ref().map(|t| t.shape());
assert_eq!(shape, case.expected.as_deref());
}
}
#[test]
fn test_reshape_in_place() {
let pool = BufferPool::new();
let mut input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]);
let shape = NdTensor::from([4]);
let expected = input.to_shape([4].as_slice());
        reshape_in_place(&pool, &mut input, &shape.view(), false).unwrap();
assert_eq!(&input, &expected);
}
#[test]
fn test_reshape_op() -> Result<(), Box<dyn Error>> {
let input = Tensor::from_data(&[2, 2], vec![-0.5, 0.5, 3.0, -5.5]);
let shape = Tensor::from([4]);
let expected = input.to_shape([4].as_slice());
let op = Reshape { allow_zero: false };
let result: Tensor<f32> = op.run_simple((&input, &shape))?;
expect_equal(&result, &expected)?;
Ok(())
}
#[test]
fn test_shape() {
#[derive(Debug)]
struct Case {
input: Value,
op: Shape,
expected: Vec<i32>,
}
let cases = [
Case {
input: Tensor::from_data(&[1, 1, 2, 2], vec![1.0, 2.0, 3.0, 4.0]).into(),
op: Shape::default(),
expected: [1, 1, 2, 2].into(),
},
Case {
input: Tensor::<i32>::zeros(&[1, 2, 3, 4]).into(),
op: Shape::default(),
expected: [1, 2, 3, 4].into(),
},
Case {
input: Tensor::<i32>::zeros(&[1, 2, 3, 4]).into(),
op: Shape {
start: Some(1),
end: Some(3),
},
expected: [2, 3].into(),
},
Case {
input: Tensor::<i32>::zeros(&[1, 2, 3, 4]).into(),
op: Shape {
start: Some(-3),
end: Some(-1),
},
expected: [2, 3].into(),
},
Case {
input: Tensor::<i32>::zeros(&[1, 2, 3, 4]).into(),
op: Shape {
start: Some(-6),
end: Some(7),
},
expected: [1, 2, 3, 4].into(),
},
Case {
input: Tensor::<i32>::zeros(&[1, 2, 3, 4]).into(),
op: Shape {
start: Some(2),
end: Some(1),
},
expected: [].into(),
},
Case {
input: Tensor::from(1i32).into(),
op: Shape::default(),
expected: [].into(),
},
];
cases.test_each(|case| {
let result: Tensor<i32> = case.op.run_simple(&case.input).unwrap();
assert_eq!(result.shape(), &[case.expected.len()]);
assert_eq!(result.to_vec(), case.expected);
});
}
#[test]
fn test_size() {
let op = Size {};
let input = Tensor::from([[1, 2], [3, 4]]);
let result: Tensor<i32> = op.run_simple(&input).unwrap();
assert_eq!(result.ndim(), 0);
assert_eq!(result.item(), Some(&4));
}
#[test]
fn test_squeeze() -> Result<(), Box<dyn Error>> {
let pool = BufferPool::new();
let mut rng = XorShiftRng::new(5678);
let input = Tensor::<f32>::rand(&[1, 5, 5, 1], &mut rng);
let mut expected = input.clone();
expected.reshape(&[5, 5]);
let result = squeeze(&pool, input.view(), None).unwrap();
expect_equal(&result, &expected)?;
expected.reshape(&[1, 5, 5]);
let result = squeeze(&pool, input.view(), Some(NdTensor::from([3]).view())).unwrap();
expect_equal(&result, &expected)?;
expected.reshape(&[5, 5, 1]);
let result = squeeze(&pool, input.view(), Some(NdTensor::from([0]).view())).unwrap();
expect_equal(&result, &expected)?;
Ok(())
}
#[test]
fn test_squeeze_in_place() -> Result<(), Box<dyn Error>> {
let mut rng = XorShiftRng::new(5678);
let mut input = Tensor::<f32>::rand(&[1, 1, 5, 5], &mut rng);
let expected = input.clone().into_shape([5, 5].as_slice());
squeeze_in_place(&mut input, None).unwrap();
expect_equal(&input, &expected)?;
let mut input = Tensor::<f32>::rand(&[1, 5, 2, 5], &mut rng);
input.permute(&[3, 2, 1, 0]);
let expected = input.clone().into_shape([5, 2, 5].as_slice());
squeeze_in_place(&mut input, None).unwrap();
expect_equal(&input, &expected)?;
Ok(())
}
#[test]
fn test_squeeze_invalid_inputs() {
let mut rng = XorShiftRng::new(5678);
let input = Tensor::<f32>::rand(&[1, 5, 5, 1], &mut rng);
let pool = BufferPool::new();
let result = squeeze(&pool, input.view(), Some(NdTensor::from([1]).view()));
assert_eq!(
result.err(),
Some(OpError::InvalidValue(
"Can only remove dimensions of size 1"
))
);
}
#[test]
fn test_transpose() -> Result<(), Box<dyn Error>> {
let pool = BufferPool::new();
let mut rng = XorShiftRng::new(5678);
let input = Tensor::<f32>::rand(&[10, 20], &mut rng);
let mut reversed = input.clone();
reversed.permute(&[1, 0]);
let result = transpose(&pool, input.view(), None).unwrap();
expect_equal(&result, &reversed)?;
let result = transpose(&pool, input.view(), Some(&[0, 1])).unwrap();
expect_equal(&result, &input)?;
let result = transpose(&pool, input.view(), Some(&[1, 0])).unwrap();
expect_equal(&result, &reversed)?;
Ok(())
}
#[test]
fn test_transpose_invalid_inputs() {
let pool = BufferPool::new();
let mut rng = XorShiftRng::new(5678);
let input = Tensor::<f32>::rand(&[10, 20], &mut rng);
let result = transpose(&pool, input.view(), Some(&[0, 1, 1]));
assert_eq!(
result.err(),
Some(OpError::InvalidValue("Permutation is invalid"))
);
let result = transpose(&pool, input.view(), Some(&[]));
assert_eq!(
result.err(),
Some(OpError::InvalidValue("Permutation is invalid"))
);
let result = transpose(&pool, input.view(), Some(&[2, 1]));
assert_eq!(
result.err(),
Some(OpError::InvalidValue("Permutation is invalid"))
);
let result = transpose(&pool, input.view(), Some(&[1, 1]));
assert_eq!(
result.err(),
Some(OpError::InvalidValue("Permutation is invalid"))
);
}
#[test]
fn test_unsqueeze() {
let pool = BufferPool::new();
let mut rng = XorShiftRng::new(5678);
let input = Tensor::<f32>::rand(&[3, 4, 5], &mut rng);
let output = unsqueeze(&pool, input.view(), &NdTensor::from([0, 4]).view()).unwrap();
assert_eq!(output.shape(), &[1, 3, 4, 5, 1]);
let output = unsqueeze(&pool, input.view(), &NdTensor::from([4, 0]).view()).unwrap();
assert_eq!(output.shape(), &[1, 3, 4, 5, 1]);
let scalar = Tensor::from(2.0);
let output = unsqueeze(&pool, scalar.view(), &NdTensor::from([0]).view()).unwrap();
assert_eq!(output.shape(), &[1]);
assert_eq!(output.to_vec(), &[2.0]);
}
#[test]
fn test_unsqueeze_invalid_inputs() {
let pool = BufferPool::new();
let mut rng = XorShiftRng::new(5678);
let input = Tensor::<f32>::rand(&[10, 20], &mut rng);
let result = unsqueeze(&pool, input.view(), &NdTensor::from([3]).view());
assert_eq!(result.err(), Some(OpError::InvalidValue("Axis is invalid")));
let result = unsqueeze(&pool, input.view(), &NdTensor::from([1, 1]).view());
assert_eq!(
result.err(),
Some(OpError::InvalidValue("Axes must be unique"))
);
}
    fn reference_transpose_into<T: Clone>(src: TensorView<T>, mut dest: TensorViewMut<T>) {
        // Merge axes to maximize the iteration count of the innermost loops.
        let mut src = src.clone();
        src.merge_axes();
        while src.ndim() < 4 {
            src.insert_axis(0);
        }
let dest_data = dest.data_mut().unwrap();
let src: NdTensorView<T, 4> = src.nd_view();
let mut dest_offset = 0;
for i0 in 0..src.size(0) {
for i1 in 0..src.size(1) {
for i2 in 0..src.size(2) {
for i3 in 0..src.size(3) {
unsafe {
let elt = src.get_unchecked([i0, i1, i2, i3]).clone();
*dest_data.get_unchecked_mut(dest_offset) = elt;
dest_offset += 1;
}
}
}
}
}
}
#[test]
#[ignore]
fn bench_transpose() {
let mut rng = XorShiftRng::new(1234);
struct Case<'a> {
shape: &'a [usize],
perm: &'a [usize],
}
let cases = [
Case {
shape: &[512, 512],
perm: &[0, 1],
},
Case {
shape: &[128, 128],
perm: &[1, 0],
},
Case {
shape: &[256, 256],
perm: &[1, 0],
},
Case {
shape: &[512, 512],
perm: &[1, 0],
},
Case {
shape: &[1024, 1024],
perm: &[1, 0],
},
Case {
shape: &[127, 127],
perm: &[1, 0],
},
Case {
shape: &[255, 255],
perm: &[1, 0],
},
Case {
shape: &[513, 513],
perm: &[1, 0],
},
Case {
shape: &[1023, 1023],
perm: &[1, 0],
},
Case {
shape: &[4, 1500, 8, 64],
perm: &[0, 2, 1, 3],
},
Case {
shape: &[4, 8, 1500, 64],
perm: &[0, 2, 1, 3],
},
Case {
shape: &[1, 1500, 8, 64],
perm: &[0, 2, 3, 1],
},
Case {
shape: &[1, 288, 8, 64],
perm: &[0, 2, 1, 3],
},
];
for Case { shape, perm } in cases {
let tensor = Tensor::<f32>::rand(shape, &mut rng);
let mut dest = Tensor::zeros(shape);
let copy_stats = run_bench(100, None, || {
dest.copy_from(&tensor.view());
});
assert_eq!(dest, tensor);
let reference_transpose_stats = run_bench(100, None, || {
let transposed = tensor.permuted(perm);
reference_transpose_into(
transposed.view(),
dest.reshaped_mut(transposed.shape()).unwrap(),
);
});
let transpose_stats = run_bench(100, None, || {
let transposed = tensor.permuted(perm);
dest.reshape(transposed.shape());
dest.copy_from(&transposed);
});
assert_eq!(dest, tensor.permuted(perm));
            // Overhead of the optimized transpose relative to a plain
            // contiguous copy, computed from mean times.
            let transpose_overhead =
                (transpose_stats.mean - copy_stats.mean).max(0.) / copy_stats.mean;
            println!(
                "transpose shape {:?} perm {:?} copy {:.3}ms ref transpose {:.3}ms opt transpose {:.3}ms overhead {:.2}",
                shape,
                perm,
                copy_stats.median,
                reference_transpose_stats.median,
                transpose_stats.median,
                transpose_overhead
            );
}
}
}