use std::mem::MaybeUninit;
use rten_shape_inference::ops as shape_ops;
use rten_tensor::prelude::*;
use rten_tensor::{AssumeInit, NdTensorView, Tensor, TensorView};
use smallvec::SmallVec;
use crate::buffer_pool::{AutoReturn, BufferPool};
use crate::infer_shapes::{InferShapes, impl_infer_shapes};
use crate::operator::{
InputList, IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType,
OutputTypeList, OutputTypesContext,
};
use crate::ops::{map_value, map_value_view, resolve_axis};
use crate::value::{TryFromValueError, Value, ValueView};
/// Compute the shape produced by concatenating tensors along `axis`.
///
/// `first_shape` is the shape of the first input and `inputs` are the
/// remaining inputs. Returns an error if any input has a different rank
/// than `first_shape`, or differs from it in a dimension other than `axis`.
fn concatenated_shape<T: Copy>(
    first_shape: &[usize],
    inputs: &[TensorView<T>],
    axis: usize,
) -> Result<SmallVec<[usize; 4]>, OpError> {
    let mut combined_shape = SmallVec::from_slice(first_shape);
    for view in inputs {
        let shape = view.shape();
        if shape.len() != first_shape.len() {
            return Err(OpError::IncompatibleInputShapes(
                "Tensors must have the same number of dimensions",
            ));
        }
        for (dim, (&expected, &actual)) in first_shape.iter().zip(shape.iter()).enumerate() {
            if dim == axis {
                // Sizes along the concat axis accumulate.
                combined_shape[axis] += actual;
            } else if expected != actual {
                return Err(OpError::IncompatibleInputShapes(
                    "Dimensions must be the same except for concat axis",
                ));
            }
        }
    }
    Ok(combined_shape)
}
fn typed_inputs<'a, T>(
inputs: &InputList<'a>,
_: TensorView<T>,
) -> Result<SmallVec<[TensorView<'a, T>; 4]>, OpError>
where
TensorView<'a, T>: TryFrom<ValueView<'a>, Error = TryFromValueError>,
{
let mut typed_inputs: SmallVec<_> = SmallVec::with_capacity(inputs.len());
for input in inputs.iter().flatten() {
typed_inputs.push(input.try_into()?);
}
Ok(typed_inputs)
}
/// Copy `first_input` followed by `inputs` into a new tensor of shape
/// `out_shape`, concatenating along `axis`. The output buffer is allocated
/// from `pool`.
fn concat_impl<T: Copy>(
    pool: &BufferPool,
    out_shape: &[usize],
    axis: usize,
    first_input: &TensorView<T>,
    inputs: &[TensorView<T>],
) -> Result<Tensor<T>, OpError> {
    // Reserve the full output size up front so every `append` below succeeds.
    let mut combined = Tensor::with_capacity_in(pool, out_shape, axis);
    combined
        .append(axis, first_input)
        .expect("should have capacity");
    for view in inputs {
        combined.append(axis, view).expect("should have capacity");
    }
    Ok(combined)
}
pub fn concat<T: Copy>(
pool: &BufferPool,
inputs: &[TensorView<T>],
axis: isize,
) -> Result<Tensor<T>, OpError> {
let axis = resolve_axis(inputs[0].ndim(), axis)?;
let out_shape = concatenated_shape(inputs[0].shape(), &inputs[1..], axis)?;
concat_impl(pool, &out_shape, axis, &inputs[0], &inputs[1..])
}
/// Append `inputs` to `output` along `axis`, reusing `output`'s existing
/// buffer if it has sufficient capacity along that axis, or allocating a
/// fresh output from `pool` otherwise.
///
/// `axis` may be negative, in which case it counts from the last dimension.
/// Returns an error if the shapes are incompatible or the axis is invalid.
pub fn concat_in_place<T: Copy>(
    pool: &BufferPool,
    mut output: Tensor<T>,
    inputs: &[TensorView<T>],
    axis: isize,
) -> Result<Tensor<T>, OpError> {
    let axis = resolve_axis(output.ndim(), axis)?;
    // Validate all input shapes up front, before any data is appended.
    let out_shape = concatenated_shape(output.shape(), inputs, axis)?;
    if !output.has_capacity(axis, out_shape[axis]) {
        // Slow path: copy existing contents plus inputs into a new tensor.
        // `auto_return` hands the old buffer back to the pool when dropped.
        let output = output.auto_return(pool);
        return concat_impl(pool, &out_shape, axis, &output.view(), inputs);
    }
    // Fast path: enough capacity to append in place without reallocating.
    for input in inputs {
        output.append(axis, input).expect("should have capacity");
    }
    Ok(output)
}
/// Operator which concatenates its inputs along a given axis.
#[derive(Debug)]
pub struct Concat {
    // Axis to concatenate along. May be negative to count from the last
    // dimension (resolved via `resolve_axis` at run time).
    pub axis: isize,
}
impl Operator for Concat {
    fn name(&self) -> &str {
        "Concat"
    }
    // `None` here means the operator accepts a variable number of inputs.
    fn max_inputs(&self) -> Option<usize> {
        None
    }
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let first = inputs.require(0)?;
        // Dispatch on the element type of the first input; `typed_inputs`
        // then requires every input to convert to that same type.
        map_value_view!(first, first, [FloatTensor, Int32Tensor], {
            let typed_inputs = typed_inputs(inputs, first)?;
            concat(ctx.pool(), &typed_inputs, self.axis).into_op_result()
        })
    }
    fn can_run_in_place(&self) -> bool {
        true
    }
    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        // NOTE(review): the in-place operand arrives as `input`, so
        // `ctx.inputs()` holds only the tensors to append to it.
        map_value!(input, input, [FloatTensor, Int32Tensor], {
            let typed_inputs = typed_inputs(ctx.inputs(), input.view())?;
            concat_in_place(ctx.pool(), input, &typed_inputs, self.axis).map(|t| t.into())
        })
    }
    // The output has the same dtype as the first input.
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(self)
    }
}
// Delegate shape inference to the shared Concat implementation in
// `rten_shape_inference`, forwarding this operator's axis attribute.
impl_infer_shapes!(
    Concat,
    op,
    shape_ops::Concat {
        axis: op.axis as i32
    }
);
/// Copy the contents of `src` into the uninitialized slice `dest` and return
/// `dest` as an initialized slice.
///
/// Panics if `src` and `dest` have different lengths (via `copy_from_slice`).
fn write_slice<'a, T>(dest: &'a mut [MaybeUninit<T>], src: &[T]) -> &'a mut [T]
where
    T: Copy,
{
    // SAFETY: `MaybeUninit<T>` has the same size and alignment as `T`, and
    // initialized `T`s are valid `MaybeUninit<T>`s. A raw pointer cast is
    // used instead of `transmute` on the fat slice reference, which is the
    // recommended way to reinterpret slice element types.
    let uninit_src: &[MaybeUninit<T>] =
        unsafe { std::slice::from_raw_parts(src.as_ptr().cast(), src.len()) };
    dest.copy_from_slice(uninit_src);
    // SAFETY: Every element of `dest` was just initialized by the copy above.
    unsafe { dest.assume_init() }
}
/// Recursively copy `input` into `output`, repeating it `repeats[i]` times
/// along each dimension `i`.
///
/// `input` is the contiguous data of a tensor with shape `input_shape`, and
/// `output` must have exactly `product(input_shape[i] * repeats[i])`
/// elements. Panics if `input_shape` and `repeats` have different lengths,
/// or if the output length is inconsistent with the shape and repeats.
fn tile_inner<T: Copy>(
    input: &[T],
    output: &mut [MaybeUninit<T>],
    input_shape: &[usize],
    repeats: &[usize],
) {
    // Fast path: no repetition, so this is a plain element-wise copy. This
    // also covers the scalar case, since `all` on an empty iterator is true.
    if repeats.iter().all(|n| *n == 1) {
        write_slice(output, input);
        return;
    }
    // Number of output elements written so far; checked at the end to verify
    // the whole output was initialized before the caller `assume_init`s it.
    let mut n_init = 0;
    match (input_shape, repeats) {
        // Base case: tile a vector by copying it end-to-end `repeats` times.
        ([size], [repeats]) => {
            assert!(input.len() == *size);
            assert!(input.len() * repeats == output.len());
            for out_chunk in output.chunks_mut(input.len()) {
                n_init += write_slice(out_chunk, input).len();
            }
        }
        // Recursive case: repeat the outermost dimension, tiling the inner
        // dimensions into each per-row output chunk.
        ([size, inner_size @ ..], [repeats, inner_repeats @ ..]) => {
            assert!(output.len().is_multiple_of(*repeats));
            let out_chunk_len = output.len() / repeats;
            let inner_input_len = input.len() / size;
            let inner_output_len = out_chunk_len / size;
            for out_chunk in output.chunks_mut(out_chunk_len) {
                for (inner_input, inner_output) in input
                    .chunks(inner_input_len)
                    .zip(out_chunk.chunks_mut(inner_output_len))
                {
                    tile_inner(inner_input, inner_output, inner_size, inner_repeats);
                    n_init += inner_output.len();
                }
            }
        }
        // Scalar case. Unreachable in practice because empty `repeats` takes
        // the all-ones fast path above, but handled for completeness.
        ([], []) => {
            n_init += write_slice(output, input).len();
        }
        _ => panic!("input_shape.len() != repeats.len()"),
    }
    assert!(n_init == output.len());
}
/// Repeat `input` `repeats[i]` times along each dimension `i`, returning a
/// new tensor allocated from `pool`.
///
/// Returns an error if `repeats` does not have one non-negative entry per
/// input dimension.
pub fn tile<T: Copy>(
    pool: &BufferPool,
    input: TensorView<T>,
    repeats: NdTensorView<i32, 1>,
) -> Result<Tensor<T>, OpError> {
    if repeats.size(0) != input.ndim() || repeats.iter().any(|n| *n < 0) {
        return Err(OpError::InvalidValue("invalid repeats"));
    }
    let repeats: Vec<usize> = repeats.iter().map(|r| *r as usize).collect();
    // Each output dimension is the input size scaled by its repeat count.
    let out_shape: Vec<_> = input
        .shape()
        .iter()
        .zip(repeats.iter())
        .map(|(size, repeat)| size * repeat)
        .collect();
    let mut output = Tensor::uninit_in(pool, &out_shape);
    if !output.is_empty() {
        tile_inner(
            // `tile_inner` works on flat slices, so make the input
            // contiguous first; the temporary buffer goes back to the pool.
            input.to_contiguous_in(pool).auto_return(pool).data(),
            output.data_mut().unwrap(),
            input.shape(),
            &repeats,
        );
    }
    // Safety: `tile_inner` asserts that it initialized every element, and an
    // empty output needs no initialization.
    let output = unsafe { output.assume_init() };
    Ok(output)
}
/// Operator which repeats its first input along each dimension a number of
/// times given by its second input.
#[derive(Debug)]
pub struct Tile {}
impl Operator for Tile {
    fn name(&self) -> &str {
        "Tile"
    }
    // Inputs are (data, repeats).
    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let inputs = ctx.inputs();
        let input = inputs.require(0)?;
        let repeats = inputs.require_as(1)?;
        map_value_view!(input, input, [FloatTensor, Int32Tensor], {
            tile(ctx.pool(), input, repeats).into_op_result()
        })
    }
    fn can_run_in_place(&self) -> bool {
        true
    }
    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        // NOTE(review): the data operand arrives as `input`, so `repeats` is
        // at index 0 of the remaining inputs (index 1 in `run` above).
        let repeats: NdTensorView<i32, 1> = ctx.inputs().require_as(0)?;
        // If every repeat count is 1 the result equals the input, so it can
        // be returned without copying.
        if repeats.iter().all(|n| *n == 1) {
            return Ok(input);
        }
        map_value!(input, input, [FloatTensor, Int32Tensor], {
            tile(ctx.pool(), input.view(), repeats).map(|t| t.into())
        })
    }
    // The output has the same dtype as the data input.
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
}
#[cfg(test)]
mod tests {
    use std::error::Error;
    use rten_tensor::Tensor;
    use rten_tensor::prelude::*;
    use rten_tensor::test_util::expect_equal;
    use rten_testing::TestCases;
    use crate::buffer_pool::BufferPool;
    use crate::ops::OpError;
    use super::{concat, concat_in_place, tile};
    // Build a 1D tensor from a slice.
    fn from_slice<T: Clone>(data: &[T]) -> Tensor<T> {
        Tensor::from_data(&[data.len()], data.to_vec())
    }
    #[test]
    fn test_concat() -> Result<(), Box<dyn Error>> {
        let pool = BufferPool::new();
        let a = Tensor::from_data(&[2, 2, 1], vec![0.1, 0.2, 0.3, 0.4]);
        let b = Tensor::from_data(&[2, 2, 1], vec![1.0, 2.0, 3.0, 4.0]);
        // Concatenation along the first axis.
        let expected = Tensor::from_data(&[4, 2, 1], vec![0.1, 0.2, 0.3, 0.4, 1.0, 2.0, 3.0, 4.0]);
        let result = concat(&pool, &[a.view(), b.view()], 0).unwrap();
        expect_equal(&result, &expected)?;
        // Concatenation along the last axis.
        let expected = Tensor::from_data(&[2, 2, 2], vec![0.1, 1.0, 0.2, 2.0, 0.3, 3.0, 0.4, 4.0]);
        let result = concat(&pool, &[a.view(), b.view()], 2).unwrap();
        expect_equal(&result, &expected)?;
        // Single input: result equals the input.
        let result = concat(&pool, &[a.view()], 0).unwrap();
        expect_equal(&result, &a)?;
        // More than two inputs.
        let result = concat(&pool, &[a.view(), b.view(), a.view()], 0).unwrap();
        assert_eq!(result.shape(), &[6, 2, 1]);
        // Empty input tensors are skipped without affecting the result.
        let a = from_slice(&[1, 2, 3]);
        let b = from_slice(&[]);
        let c = from_slice(&[4, 5, 6]);
        let result = concat(&pool, &[a.view(), b.view(), c.view()], 0).unwrap();
        assert_eq!(result.shape(), &[6]);
        assert_eq!(result.to_vec(), &[1, 2, 3, 4, 5, 6]);
        Ok(())
    }
    #[test]
    fn test_concat_in_place() -> Result<(), Box<dyn Error>> {
        let pool = BufferPool::new();
        // Start with an empty tensor that has capacity for a [3, 3] result.
        let dest = Tensor::with_capacity_in(&pool, &[3, 3], 1);
        let dest =
            concat_in_place(&pool, dest, &[Tensor::from([[1], [2], [3]]).view()], 1).unwrap();
        let dest =
            concat_in_place(&pool, dest, &[Tensor::from([[4], [5], [6]]).view()], 1).unwrap();
        let dest =
            concat_in_place(&pool, dest, &[Tensor::from([[7], [8], [9]]).view()], 1).unwrap();
        assert_eq!(dest.shape(), &[3, 3]);
        assert_eq!(dest, Tensor::from([[1, 4, 7], [2, 5, 8], [3, 6, 9],]));
        // Appending beyond the reserved capacity falls back to a new buffer.
        let dest =
            concat_in_place(&pool, dest, &[Tensor::from([[10], [11], [12]]).view()], 1).unwrap();
        assert_eq!(dest.shape(), &[3, 4]);
        assert_eq!(
            dest,
            Tensor::from([[1, 4, 7, 10], [2, 5, 8, 11], [3, 6, 9, 12],])
        );
        // Mismatched size on a non-concat dimension is rejected.
        let result = concat_in_place(&pool, dest.clone(), &[Tensor::from([[1, 2, 3]]).view()], 1);
        assert!(result.is_err());
        // Mismatched rank is rejected.
        let result = concat_in_place(&pool, dest.clone(), &[Tensor::from(1).view()], 1);
        assert!(result.is_err());
        Ok(())
    }
    #[test]
    fn test_concat_invalid_inputs() {
        let pool = BufferPool::new();
        // Invalid `axis`.
        let input = from_slice(&[1, 2, 3]);
        let result = concat(&pool, &[input.view(), input.view()], 1);
        assert_eq!(result.err(), Some(OpError::InvalidValue("Axis is invalid")));
        // Inputs with different ranks.
        let a = Tensor::<f32>::zeros(&[1]);
        let b = Tensor::<f32>::zeros(&[1, 2]);
        let result = concat(&pool, &[a.view(), b.view()], 0);
        assert_eq!(
            result.err(),
            Some(OpError::IncompatibleInputShapes(
                "Tensors must have the same number of dimensions"
            ))
        );
        // Inputs that differ in a non-concat dimension.
        let a = Tensor::<f32>::zeros(&[5, 10]);
        let b = Tensor::<f32>::zeros(&[5, 11]);
        let result = concat(&pool, &[a.view(), b.view()], 0);
        assert_eq!(
            result.err(),
            Some(OpError::IncompatibleInputShapes(
                "Dimensions must be the same except for concat axis"
            ))
        );
    }
    #[test]
    fn test_tile() {
        #[derive(Debug)]
        struct Case {
            input: Tensor<i32>,
            repeats: Tensor<i32>,
            expected: Tensor<i32>,
        }
        let cases = [
            // Zero repeats produce an empty output.
            Case {
                input: Tensor::<i32>::zeros(&[3, 4, 5]),
                repeats: Tensor::from([4, 0, 1]),
                expected: Tensor::<i32>::zeros(&[12, 0, 5]),
            },
            // Scalar input with empty repeats.
            Case {
                input: Tensor::from(5),
                repeats: Tensor::from([] as [i32; 0]),
                expected: Tensor::from(5),
            },
            // 1D tiling.
            Case {
                input: Tensor::from([1, 2, 3, 4]),
                repeats: Tensor::from([3]),
                expected: Tensor::from([1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]),
            },
            // Single-element 2D input tiled in both dimensions.
            Case {
                input: Tensor::from([[3]]),
                repeats: Tensor::from([3, 2]),
                expected: Tensor::from([[3, 3], [3, 3], [3, 3]]),
            },
            // All-ones repeats: output equals input.
            Case {
                input: Tensor::from([1, 2, 3, 4]),
                repeats: Tensor::from([1]),
                expected: Tensor::from([1, 2, 3, 4]),
            },
            // Repeat along the inner dimension only.
            Case {
                input: Tensor::from([[1, 2], [3, 4]]),
                repeats: Tensor::from([1, 2]),
                expected: Tensor::from([[1, 2, 1, 2], [3, 4, 3, 4]]),
            },
            // Repeat along the outer dimension only.
            Case {
                input: Tensor::from([[1, 2], [3, 4]]),
                repeats: Tensor::from([2, 1]),
                expected: Tensor::from([[1, 2], [3, 4], [1, 2], [3, 4]]),
            },
            // Repeat along both dimensions.
            Case {
                input: Tensor::from([[1, 2], [3, 4]]),
                repeats: Tensor::from([2, 2]),
                expected: Tensor::from([[1, 2, 1, 2], [3, 4, 3, 4], [1, 2, 1, 2], [3, 4, 3, 4]]),
            },
        ];
        cases.test_each(|case| {
            let pool = BufferPool::new();
            let Case {
                input,
                repeats,
                expected,
            } = case;
            let result = tile(&pool, input.view(), repeats.nd_view()).unwrap();
            expect_equal(&result, &expected).unwrap();
        });
    }
    #[test]
    fn test_tile_invalid_repeats() {
        #[derive(Debug)]
        struct Case {
            input: Tensor<i32>,
            repeats: Tensor<i32>,
            expected_error: OpError,
        }
        let cases = [
            // Wrong number of repeat entries.
            Case {
                input: Tensor::from([1, 2, 3]),
                repeats: Tensor::from([1, 2]),
                expected_error: OpError::InvalidValue("invalid repeats"),
            },
            // Negative repeat count.
            Case {
                input: Tensor::from([1, 2, 3]),
                repeats: Tensor::from([-1]),
                expected_error: OpError::InvalidValue("invalid repeats"),
            },
        ];
        cases.test_each(|case| {
            let pool = BufferPool::new();
            let Case {
                input,
                repeats,
                expected_error,
            } = case;
            let result = tile(&pool, input.view(), repeats.nd_view());
            assert_eq!(result.err().as_ref(), Some(expected_error));
        });
    }
}