use rten_base::byte_cast::{Pod, cast_pod_vec};
use rten_base::num;
use rten_tensor::Tensor;
use rten_tensor::prelude::*;
use crate::buffer_pool::BufferPool;
use crate::infer_shapes::{InferShapes, InferShapesError, SymTensor, SymbolGen, UnaryOp};
use crate::operator::{
IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
OutputTypesContext,
};
use crate::value::{DataType, Value, ValueType, ValueView};
/// Copy `input` into a freshly allocated tensor with element type `dtype`.
///
/// The output buffer is allocated from `pool`. Element conversion follows
/// Rust `as`-cast semantics (e.g. out-of-range float -> int saturates).
/// Sequence values cannot be cast and yield `OpError::UnsupportedType`.
fn cast(pool: &BufferPool, input: ValueView, dtype: DataType) -> Result<Value, OpError> {
    // `convert!(t)` copies the tensor without element conversion;
    // `convert!(t => U)` converts each element with `as U`.
    macro_rules! convert {
        ($t:ident) => {
            Ok($t.to_tensor_in(pool).into())
        };
        ($t:ident => $dst:ty) => {
            Ok($t.map_in(pool, |el| *el as $dst).into())
        };
    }

    // Dispatch on the source type first, then on the requested target type.
    match input {
        ValueView::FloatTensor(t) => match dtype {
            DataType::Float => convert!(t),
            DataType::Int32 => convert!(t => i32),
            DataType::Int8 => convert!(t => i8),
            DataType::UInt8 => convert!(t => u8),
        },
        ValueView::Int32Tensor(t) => match dtype {
            DataType::Int32 => convert!(t),
            DataType::Float => convert!(t => f32),
            DataType::Int8 => convert!(t => i8),
            DataType::UInt8 => convert!(t => u8),
        },
        ValueView::Int8Tensor(t) => match dtype {
            DataType::Int8 => convert!(t),
            DataType::Float => convert!(t => f32),
            DataType::Int32 => convert!(t => i32),
            DataType::UInt8 => convert!(t => u8),
        },
        ValueView::UInt8Tensor(t) => match dtype {
            DataType::UInt8 => convert!(t),
            DataType::Float => convert!(t => f32),
            DataType::Int32 => convert!(t => i32),
            DataType::Int8 => convert!(t => i8),
        },
        ValueView::Sequence(_) => Err(OpError::UnsupportedType),
    }
}
/// Convert a tensor's elements from `T` to `U`, reusing the tensor's
/// existing buffer rather than allocating a new one.
///
/// The `U: Pod<Bytes = T::Bytes>` bound means `T` and `U` share the same
/// byte representation size, which is what makes the in-place
/// reinterpretation possible.
fn cast_tensor<T, U>(mut data: Tensor<T>) -> Tensor<U>
where
    T: Pod + num::Cast<U>,
    U: Pod<Bytes = T::Bytes>,
{
    // Step 1: overwrite each element slot with the bytes of its converted
    // value (still typed as `T` at this point).
    data.apply(|el| num::Cast::<U>::cast(*el).cast_bytes());
    // Step 2: reinterpret the backing buffer as `Vec<U>`. The unwrap is
    // expected to succeed because `T` and `U` have matching byte layouts
    // per the trait bounds above.
    let dims: Vec<usize> = data.shape().to_vec();
    let buf = cast_pod_vec::<T, U>(data.into_data()).unwrap();
    Tensor::from_data(&dims, buf)
}
/// Try to cast `input` to `dtype` while reusing its existing buffer.
///
/// Only no-op casts (source type already equals `dtype`) and conversions
/// between types of equal element size (f32 <-> i32, i8 <-> u8) are
/// supported. Any other combination returns the unmodified input as the
/// `Err` value so the caller can fall back to an out-of-place cast.
fn cast_in_place(input: Value, dtype: DataType) -> Result<Value, Value> {
    match (dtype, input) {
        (DataType::Int32, Value::Int32Tensor(t)) => Ok(t.into()),
        (DataType::Int32, Value::FloatTensor(t)) => Ok(cast_tensor::<_, i32>(t).into()),
        (DataType::Float, Value::FloatTensor(t)) => Ok(t.into()),
        (DataType::Float, Value::Int32Tensor(t)) => Ok(cast_tensor::<_, f32>(t).into()),
        (DataType::Int8, Value::Int8Tensor(t)) => Ok(t.into()),
        (DataType::Int8, Value::UInt8Tensor(t)) => Ok(cast_tensor::<_, i8>(t).into()),
        (DataType::UInt8, Value::UInt8Tensor(t)) => Ok(t.into()),
        (DataType::UInt8, Value::Int8Tensor(t)) => Ok(cast_tensor::<_, u8>(t).into()),
        // Mismatched element sizes (or a sequence input): hand the value
        // back unchanged.
        (_, other) => Err(other),
    }
}
/// Operator which converts a tensor's elements to a specified data type.
#[derive(Debug)]
pub struct Cast {
    // Target element type of the output tensor.
    pub to: DataType,
}
impl Operator for Cast {
    fn name(&self) -> &str {
        "Cast"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }

    /// Cast the single input to `self.to`, allocating a new output tensor
    /// from the context's pool.
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        cast(ctx.pool(), ctx.inputs().require(0)?, self.to).into_op_result()
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    /// Cast `input`, reusing its buffer when the element sizes allow.
    /// Otherwise fall back to an out-of-place cast and return the original
    /// buffer to the pool for reuse.
    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        cast_in_place(input, self.to).or_else(|original| {
            let output = cast(ctx.pool(), original.as_view(), self.to)?;
            original.add_to_pool(ctx.pool());
            Ok(output)
        })
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(self)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // The output type is determined entirely by the `to` attribute.
        Some([OutputType::Fixed(ValueType::Tensor(self.to))].into())
    }
}
impl InferShapes for Cast {
    fn infer_shapes(
        &self,
        inputs: &[SymTensor],
        _sym_gen: &mut SymbolGen,
    ) -> Result<Vec<SymTensor>, InferShapesError> {
        let data = inputs
            .first()
            .ok_or(InferShapesError::IncorrectInputCount)?;

        // If the input's element values are symbolically known and the cast
        // target is Int32, the value tensor can be propagated as-is
        // (presumably because symbolic values are tracked as i32 — other
        // targets would change the values). Otherwise only the shape, if
        // known, carries over.
        let output = if self.to == DataType::Int32 && data.values().is_some() {
            data.clone()
        } else {
            match data.shape() {
                Some(shape) => SymTensor::from_shape(shape.collect()),
                None => SymTensor::unknown("unknown input shape"),
            }
        };
        Ok(vec![output])
    }
}
/// Operator which casts its first input to the data type of its second input.
#[derive(Debug)]
pub struct CastLike {}
impl Operator for CastLike {
    fn name(&self) -> &str {
        "CastLike"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }

    /// Cast input 0 to the element type of input 1 by delegating to [`Cast`].
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let to = match ctx.inputs().require(1)?.dtype() {
            ValueType::Tensor(dtype) => dtype,
            _ => return Err(OpError::InvalidValue("expected target_type to be a tensor")),
        };
        Cast { to }.run(ctx)
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    /// In-place variant of `run`. The value being cast arrives separately as
    /// `input`, so the target-type tensor is expected at index 0 of the
    /// remaining inputs (index 1 in the non-in-place case).
    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        let to = match ctx.inputs().require(0)?.dtype() {
            ValueType::Tensor(dtype) => dtype,
            _ => return Err(OpError::InvalidValue("expected target_type to be a tensor")),
        };
        Cast { to }.run_in_place(input, ctx)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        // Output type mirrors that of the second ("like") input.
        Some([OutputType::CopyFromInput(1)].into())
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&UnaryOp)
    }
}
#[cfg(test)]
mod tests {
    use rten_tensor::Tensor;
    use rten_testing::TestCases;
    use super::{Cast, CastLike};
    use crate::operator::{InputList, OperatorExt};
    use crate::value::{DataType, Value, ValueType};

    #[test]
    fn test_cast() {
        #[derive(Debug)]
        struct Case {
            // Source tensor to cast.
            input: Value,
            // Target element type.
            dtype: DataType,
            // Expected output tensor.
            expected: Value,
        }
        let cases = [
            // i32 -> f32
            Case {
                input: Tensor::from([1, 2, 3]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([1., 2., 3.]).into(),
            },
            // i32 -> i32 (no-op cast)
            Case {
                input: Tensor::from([1, 2, 3]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([1, 2, 3]).into(),
            },
            // i32 -> i8, values within i8's range
            Case {
                input: Tensor::from([i8::MIN as i32, 0, i8::MAX as i32]).into(),
                dtype: DataType::Int8,
                expected: Tensor::from([i8::MIN, 0, i8::MAX]).into(),
            },
            // i32 -> u8, values within u8's range
            Case {
                input: Tensor::from([u8::MIN as i32, 0, u8::MAX as i32]).into(),
                dtype: DataType::UInt8,
                expected: Tensor::from([u8::MIN, 0, u8::MAX]).into(),
            },
            // f32 -> i32
            Case {
                input: Tensor::from([1., 2., 3.]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([1, 2, 3]).into(),
            },
            // f32 -> f32 (no-op cast)
            Case {
                input: Tensor::from([1., 2., 3.]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([1., 2., 3.]).into(),
            },
            // i32 -> f32 at the limits of i32's range. Conversion rounds to
            // the nearest representable f32; the expected literals below
            // round to those same f32 values.
            Case {
                input: Tensor::from([i32::MIN, i32::MAX]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([-2147483600.0, 2147483600.0]).into(),
            },
            // f32 -> i32 with values outside i32's range. Rust `as` casts
            // from float to int saturate rather than wrap.
            Case {
                input: Tensor::from([f32::MIN, f32::MAX]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([i32::MIN, i32::MAX]).into(),
            },
        ];
        cases.test_each(|case| {
            // Out-of-place cast.
            let cast_op = Cast { to: case.dtype };
            let result: Value = cast_op.run_simple(&case.input).unwrap();
            assert_eq!(result, case.expected);

            // In-place cast. Only exercised when source and target element
            // sizes match, since those are the conversions the in-place path
            // supports.
            let input_dtype = match case.input.dtype() {
                ValueType::Tensor(dtype) => Some(dtype),
                _ => None,
            };
            if input_dtype.unwrap().size() == case.dtype.size() {
                let result: Value = cast_op
                    .run_simple_in_place(case.input.clone(), InputList::new())
                    .unwrap();
                assert_eq!(result, case.expected);
            }
        })
    }

    #[test]
    fn test_cast_like() {
        #[derive(Debug)]
        struct Case {
            // Value to cast.
            input: Value,
            // Value whose dtype determines the cast target.
            other: Value,
            // Expected output.
            expected: Value,
        }
        let cases = [
            // i32 -> f32, target type taken from `other`.
            Case {
                input: Tensor::from([0i32, 1, 2]).into(),
                other: Tensor::from([0f32]).into(),
                expected: Tensor::from([0., 1., 2.]).into(),
            },
        ];
        cases.test_each(|case| {
            let cast_op = CastLike {};
            let result: Value = cast_op.run_simple((&case.input, &case.other)).unwrap();
            assert_eq!(result, case.expected);
        })
    }
}