//! rten 0.24.0
//!
//! Machine learning runtime
//! Documentation
use rten_base::byte_cast::{Pod, cast_pod_vec};
use rten_base::num;

use rten_tensor::Tensor;
use rten_tensor::prelude::*;

use crate::buffer_pool::BufferPool;
use crate::infer_shapes::{InferShapes, InferShapesError, SymTensor, SymbolGen, UnaryOp};
use crate::operator::{
    IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
    OutputTypesContext,
};
use crate::value::{DataType, Value, ValueType, ValueView};

/// Copy `input` into a new tensor allocated from `pool`, converting its
/// elements to `dtype`.
///
/// Returns [`OpError::UnsupportedType`] if the input is a sequence, which
/// the ONNX Cast op does not support (although logically this could be
/// supported by casting each tensor in the sequence).
fn cast(pool: &BufferPool, input: ValueView, dtype: DataType) -> Result<Value, OpError> {
    match dtype {
        DataType::Int32 => match input {
            // Input already has the target dtype: plain pool-allocated copy.
            ValueView::Int32Tensor(t) => Ok(t.to_tensor_in(pool).into()),
            ValueView::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
            ValueView::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
            ValueView::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i32).into()),
            ValueView::Sequence(_) => Err(OpError::UnsupportedType),
        },
        DataType::Float => match input {
            ValueView::FloatTensor(t) => Ok(t.to_tensor_in(pool).into()),
            ValueView::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
            ValueView::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
            ValueView::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as f32).into()),
            ValueView::Sequence(_) => Err(OpError::UnsupportedType),
        },
        DataType::Int8 => match input {
            ValueView::Int8Tensor(t) => Ok(t.to_tensor_in(pool).into()),
            ValueView::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
            ValueView::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
            ValueView::UInt8Tensor(t) => Ok(t.map_in(pool, |x| *x as i8).into()),
            ValueView::Sequence(_) => Err(OpError::UnsupportedType),
        },
        DataType::UInt8 => match input {
            ValueView::UInt8Tensor(t) => Ok(t.to_tensor_in(pool).into()),
            ValueView::FloatTensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
            ValueView::Int32Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
            ValueView::Int8Tensor(t) => Ok(t.map_in(pool, |x| *x as u8).into()),
            ValueView::Sequence(_) => Err(OpError::UnsupportedType),
        },
    }
}

/// Cast a tensor from type T to U in-place.
///
/// Both T and U must have the same size. This is enforced by the
/// `U: Pod<Bytes = T::Bytes>` bound below, which requires both types to
/// share the same byte-array representation type.
fn cast_tensor<T, U>(mut data: Tensor<T>) -> Tensor<U>
where
    T: Pod + num::Cast<U>,
    U: Pod<Bytes = T::Bytes>,
{
    // Cast elements from type T to U in place. Each element is converted to
    // U and then written back into the T-typed buffer via its byte
    // representation (`cast_bytes`), reusing the existing allocation.
    data.apply(|x| num::Cast::<U>::cast(*x).cast_bytes());

    // Extract the converted data and transmute from T to U.
    let shape = data.shape().to_vec();
    // The unwrap is expected to succeed because the `Bytes = T::Bytes` bound
    // above guarantees T and U have matching byte representations.
    let data = cast_pod_vec::<T, U>(data.into_data()).unwrap();
    Tensor::from_data(&shape, data)
}

/// Cast elements of `input` to a given dtype in place, or return the input
/// value if the cast is not possible.
///
/// In-place casting is supported only between dtypes with the same element
/// size (via [`cast_tensor`]) or when the input already has the requested
/// dtype.
fn cast_in_place(input: Value, dtype: DataType) -> Result<Value, Value> {
    match (dtype, input) {
        // Input already has the requested dtype: nothing to do.
        (DataType::Int32, Value::Int32Tensor(t)) => Ok(t.into()),
        (DataType::Float, Value::FloatTensor(t)) => Ok(t.into()),
        (DataType::Int8, Value::Int8Tensor(t)) => Ok(t.into()),
        (DataType::UInt8, Value::UInt8Tensor(t)) => Ok(t.into()),
        // Same-size conversions which can reuse the existing buffer.
        (DataType::Int32, Value::FloatTensor(t)) => Ok(cast_tensor::<_, i32>(t).into()),
        (DataType::Float, Value::Int32Tensor(t)) => Ok(cast_tensor::<_, f32>(t).into()),
        (DataType::Int8, Value::UInt8Tensor(t)) => Ok(cast_tensor::<_, i8>(t).into()),
        (DataType::UInt8, Value::Int8Tensor(t)) => Ok(cast_tensor::<_, u8>(t).into()),
        // Any other combination requires a copying cast: hand the value back
        // to the caller.
        (_, other) => Err(other),
    }
}

/// Operator which casts the elements of a tensor to a given data type.
///
/// This implements the ONNX `Cast` operator.
#[derive(Debug)]
pub struct Cast {
    /// The data type which input elements are cast to.
    pub to: DataType,
}

impl Operator for Cast {
    fn name(&self) -> &str {
        "Cast"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let value = ctx.inputs().require(0)?;
        cast(ctx.pool(), value, self.to).into_op_result()
    }

    fn can_run_in_place(&self) -> bool {
        // `run_in_place` accepts any input: it converts in place when the
        // input's dtype already matches `self.to` or both dtypes have the
        // same element size, and falls back to a copying cast otherwise.
        true
    }

    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        cast_in_place(input, self.to).or_else(|original| {
            // The in-place cast was not possible. Perform a copying cast and
            // return the original buffer to the pool for reuse.
            let converted = cast(ctx.pool(), original.as_view(), self.to)?;
            original.add_to_pool(ctx.pool());
            Ok(converted)
        })
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(self)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::Fixed(ValueType::Tensor(self.to))].into())
    }
}

impl InferShapes for Cast {
    fn infer_shapes(
        &self,
        inputs: &[SymTensor],
        _sym_gen: &mut SymbolGen,
    ) -> Result<Vec<SymTensor>, InferShapesError> {
        let data = inputs
            .first()
            .ok_or(InferShapesError::IncorrectInputCount)?;

        // A no-op cast from int to int preserves the symbolic values of the
        // input. Any other cast behaves like a generic unary operator and
        // preserves only the shape.
        let preserves_values = data.values().is_some() && self.to == DataType::Int32;
        let output = if preserves_values {
            data.clone()
        } else {
            match data.shape() {
                Some(shape) => SymTensor::from_shape(shape.collect()),
                None => SymTensor::unknown("unknown input shape"),
            }
        };

        Ok(vec![output])
    }
}

/// Operator which casts the elements of one tensor to the data type of a
/// second tensor.
///
/// This implements the ONNX `CastLike` operator. The conversion itself is
/// delegated to [`Cast`].
#[derive(Debug)]
pub struct CastLike {}

impl Operator for CastLike {
    fn name(&self) -> &str {
        "CastLike"
    }

    fn max_inputs(&self) -> Option<usize> {
        Some(2)
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let target_type = ctx.inputs().require(1)?;
        let ValueType::Tensor(to) = target_type.dtype() else {
            return Err(OpError::InvalidValue("expected target_type to be a tensor"));
        };
        Cast { to }.run(ctx)
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
        let target_type = ctx.inputs().require(0)?;
        let ValueType::Tensor(to) = target_type.dtype() else {
            return Err(OpError::InvalidValue("expected target_type to be a tensor"));
        };
        Cast { to }.run_in_place(input, ctx)
    }

    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(1)].into())
    }

    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&UnaryOp)
    }
}

#[cfg(test)]
mod tests {
    use rten_tensor::Tensor;
    use rten_testing::TestCases;

    use super::{Cast, CastLike};
    use crate::operator::{InputList, OperatorExt};
    use crate::value::{DataType, Value, ValueType};

    #[test]
    fn test_cast() {
        // One conversion scenario: cast `input` to `dtype`, expecting
        // `expected` as the result.
        #[derive(Debug)]
        struct Case {
            input: Value,
            dtype: DataType,
            expected: Value,
        }

        let cases = [
            // i32 -> f32
            Case {
                input: Tensor::from([1, 2, 3]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([1., 2., 3.]).into(),
            },
            // i32 -> i32 (no-op cast)
            Case {
                input: Tensor::from([1, 2, 3]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([1, 2, 3]).into(),
            },
            // i32 -> i8
            Case {
                input: Tensor::from([i8::MIN as i32, 0, i8::MAX as i32]).into(),
                dtype: DataType::Int8,
                expected: Tensor::from([i8::MIN, 0, i8::MAX]).into(),
            },
            // i32 -> u8
            Case {
                input: Tensor::from([u8::MIN as i32, 0, u8::MAX as i32]).into(),
                dtype: DataType::UInt8,
                expected: Tensor::from([u8::MIN, 0, u8::MAX]).into(),
            },
            // f32 -> i32
            Case {
                input: Tensor::from([1., 2., 3.]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([1, 2, 3]).into(),
            },
            // f32 -> f32 (no-op cast)
            Case {
                input: Tensor::from([1., 2., 3.]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([1., 2., 3.]).into(),
            },
            // Int -> float out of range. This will lose precision since f32
            // cannot represent every i32 exactly near the extremes.
            Case {
                input: Tensor::from([i32::MIN, i32::MAX]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([-2147483600.0, 2147483600.0]).into(),
            },
            // Float -> int out of range.
            //
            // In RTen this saturates following the behavior of Rust's `as`
            // operator. This is different than C++ / PyTorch / NumPy where
            // the behavior of such conversions is undefined.
            // See https://github.com/robertknight/rten/pull/387#issuecomment-2420343989.
            Case {
                input: Tensor::from([f32::MIN, f32::MAX]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([i32::MIN, i32::MAX]).into(),
            },
        ];

        cases.test_each(|case| {
            // Copying cast.
            let cast_op = Cast { to: case.dtype };
            let result: Value = cast_op.run_simple(&case.input).unwrap();
            assert_eq!(result, case.expected);

            let input_dtype = match case.input.dtype() {
                ValueType::Tensor(dtype) => Some(dtype),
                _ => None,
            };

            // In-place cast. Only attempted when the source and target
            // dtypes have the same element size, which is the condition for
            // an in-place conversion to be possible.
            if input_dtype.unwrap().size() == case.dtype.size() {
                let result: Value = cast_op
                    .run_simple_in_place(case.input.clone(), InputList::new())
                    .unwrap();
                assert_eq!(result, case.expected);
            }
        })
    }

    #[test]
    fn test_cast_like() {
        // One scenario: cast `input` to the dtype of `other`, expecting
        // `expected` as the result.
        #[derive(Debug)]
        struct Case {
            input: Value,
            other: Value,
            expected: Value,
        }

        // `CastLike` uses the same conversions as the `Cast` operator,
        // so these tests don't check all data type combinations, only that the
        // target type is taken from the second argument.
        let cases = [
            // i32 -> f32
            Case {
                input: Tensor::from([0i32, 1, 2]).into(),
                other: Tensor::from([0f32]).into(),
                expected: Tensor::from([0., 1., 2.]).into(),
            },
        ];

        cases.test_each(|case| {
            let cast_op = CastLike {};
            let result: Value = cast_op.run_simple((&case.input, &case.other)).unwrap();
            assert_eq!(result, case.expected);
        })
    }
}