rten/ops/
convert.rs

1use rten_base::byte_cast::{Pod, cast_pod_vec};
2use rten_base::num;
3
4use rten_tensor::Tensor;
5use rten_tensor::prelude::*;
6
7use crate::buffer_pool::BufferPool;
8use crate::ops::{
9    DataType, IntoOpResult, OpError, OpRunContext, Operator, OutputList, Value, ValueView,
10};
11
12fn cast(pool: &BufferPool, input: ValueView, dtype: DataType) -> Result<Value, OpError> {
13    macro_rules! cast_as {
14        ($x:ident) => {
15            Ok($x.to_tensor_in(pool).into())
16        };
17
18        ($x:ident, $dest_ty:ty) => {
19            Ok($x.map_in(pool, |x| *x as $dest_ty).into())
20        };
21    }
22
23    match dtype {
24        DataType::Int32 => match input {
25            ValueView::Int32Tensor(t) => cast_as!(t),
26            ValueView::FloatTensor(t) => cast_as!(t, i32),
27            ValueView::Int8Tensor(t) => cast_as!(t, i32),
28            ValueView::UInt8Tensor(t) => cast_as!(t, i32),
29
30            // The ONNX Cast op doesn't support sequences, although logically
31            // this could be supported by casting each tensor in the sequence.
32            ValueView::Sequence(_) => Err(OpError::UnsupportedType),
33        },
34        DataType::Float => match input {
35            ValueView::FloatTensor(t) => cast_as!(t),
36            ValueView::Int32Tensor(t) => cast_as!(t, f32),
37            ValueView::Int8Tensor(t) => cast_as!(t, f32),
38            ValueView::UInt8Tensor(t) => cast_as!(t, f32),
39            ValueView::Sequence(_) => Err(OpError::UnsupportedType),
40        },
41        DataType::Int8 => match input {
42            ValueView::Int8Tensor(t) => cast_as!(t),
43            ValueView::FloatTensor(t) => cast_as!(t, i8),
44            ValueView::Int32Tensor(t) => cast_as!(t, i8),
45            ValueView::UInt8Tensor(t) => cast_as!(t, i8),
46            ValueView::Sequence(_) => Err(OpError::UnsupportedType),
47        },
48        DataType::UInt8 => match input {
49            ValueView::UInt8Tensor(t) => cast_as!(t),
50            ValueView::FloatTensor(t) => cast_as!(t, u8),
51            ValueView::Int32Tensor(t) => cast_as!(t, u8),
52            ValueView::Int8Tensor(t) => cast_as!(t, u8),
53            ValueView::Sequence(_) => Err(OpError::UnsupportedType),
54        },
55    }
56}
57
58/// Cast a tensor from type T to U in-place.
59///
60/// Both T and U must have the same size.
61fn cast_tensor<T, U>(mut data: Tensor<T>) -> Tensor<U>
62where
63    T: Pod + num::Cast<U>,
64    U: Pod<Bytes = T::Bytes>,
65{
66    // Cast elements from type T to U in place.
67    data.apply(|x| num::Cast::<U>::cast(*x).cast_bytes());
68
69    // Extract the converted data and transmute from T to U.
70    let shape = data.shape().to_vec();
71    let data = cast_pod_vec::<T, U>(data.into_data()).unwrap();
72    Tensor::from_data(&shape, data)
73}
74
75/// Cast elements of `input` to a given dtype in place, or return the input
76/// value if the cast is not possible.
77fn cast_in_place(input: Value, dtype: DataType) -> Result<Value, Value> {
78    match dtype {
79        DataType::Int32 => match input {
80            Value::Int32Tensor(t) => Ok(t.into()),
81            Value::FloatTensor(t) => Ok(cast_tensor::<_, i32>(t).into()),
82            _ => Err(input),
83        },
84        DataType::Float => match input {
85            Value::FloatTensor(t) => Ok(t.into()),
86            Value::Int32Tensor(t) => Ok(cast_tensor::<_, f32>(t).into()),
87            _ => Err(input),
88        },
89        DataType::Int8 => match input {
90            Value::Int8Tensor(t) => Ok(t.into()),
91            Value::UInt8Tensor(t) => Ok(cast_tensor::<_, i8>(t).into()),
92            _ => Err(input),
93        },
94        DataType::UInt8 => match input {
95            Value::UInt8Tensor(t) => Ok(t.into()),
96            Value::Int8Tensor(t) => Ok(cast_tensor::<_, u8>(t).into()),
97            _ => Err(input),
98        },
99    }
100}
101
102#[derive(Debug)]
103pub struct Cast {
104    pub to: DataType,
105}
106
107impl Operator for Cast {
108    fn name(&self) -> &str {
109        "Cast"
110    }
111
112    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
113        let input = ctx.inputs().require(0)?;
114        cast(ctx.pool(), input, self.to).into_op_result()
115    }
116
117    fn can_run_in_place(&self) -> bool {
118        // Cast can run in place if the input's dtype already matches `self.to`
119        // or both dtypes have the same element size.
120        true
121    }
122
123    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
124        match cast_in_place(input, self.to) {
125            Ok(output) => Ok(output),
126            Err(input) => {
127                let converted = cast(ctx.pool(), input.as_view(), self.to)?;
128                input.add_to_pool(ctx.pool());
129                Ok(converted)
130            }
131        }
132    }
133}
134
135#[derive(Debug)]
136pub struct CastLike {}
137
138impl Operator for CastLike {
139    fn name(&self) -> &str {
140        "CastLike"
141    }
142
143    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
144        let to = ctx.inputs().require(1)?.dtype();
145        Cast { to }.run(ctx)
146    }
147
148    fn can_run_in_place(&self) -> bool {
149        true
150    }
151
152    fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
153        let to = ctx.inputs().require(0)?.dtype();
154        Cast { to }.run_in_place(input, ctx)
155    }
156}
157
#[cfg(test)]
mod tests {
    use rten_tensor::Tensor;
    use rten_testing::TestCases;

    use crate::ops::{Cast, CastLike, DataType, InputList, OperatorExt, Value};

    #[test]
    fn test_cast() {
        #[derive(Debug)]
        struct Case {
            input: Value,
            dtype: DataType,
            expected: Value,
        }

        let cases = [
            // i32 -> f32
            Case {
                input: Tensor::from([1, 2, 3]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([1., 2., 3.]).into(),
            },
            // i32 -> i32
            Case {
                input: Tensor::from([1, 2, 3]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([1, 2, 3]).into(),
            },
            // i32 -> i8
            Case {
                input: Tensor::from([i8::MIN as i32, 0, i8::MAX as i32]).into(),
                dtype: DataType::Int8,
                expected: Tensor::from([i8::MIN, 0, i8::MAX]).into(),
            },
            // i32 -> u8
            Case {
                input: Tensor::from([u8::MIN as i32, 0, u8::MAX as i32]).into(),
                dtype: DataType::UInt8,
                expected: Tensor::from([u8::MIN, 0, u8::MAX]).into(),
            },
            // f32 -> i32
            Case {
                input: Tensor::from([1., 2., 3.]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([1, 2, 3]).into(),
            },
            // f32 -> f32
            Case {
                input: Tensor::from([1., 2., 3.]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([1., 2., 3.]).into(),
            },
            // u8 -> i8 (values representable in both types)
            Case {
                input: Tensor::from([0u8, 1, 127]).into(),
                dtype: DataType::Int8,
                expected: Tensor::from([0i8, 1, 127]).into(),
            },
            // i8 -> u8 (values representable in both types)
            Case {
                input: Tensor::from([0i8, 1, 127]).into(),
                dtype: DataType::UInt8,
                expected: Tensor::from([0u8, 1, 127]).into(),
            },
            // Int -> float where the values are not exactly representable
            // in f32. This loses precision.
            Case {
                input: Tensor::from([i32::MIN, i32::MAX]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([-2147483600.0, 2147483600.0]).into(),
            },
            // Float -> int out of range.
            //
            // In RTen this saturates following the behavior of Rust's `as`
            // operator. This is different than C++ / PyTorch / NumPy where
            // the behavior of such conversions is undefined.
            // See https://github.com/robertknight/rten/pull/387#issuecomment-2420343989.
            Case {
                input: Tensor::from([f32::MIN, f32::MAX]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([i32::MIN, i32::MAX]).into(),
            },
        ];

        cases.test_each(|case| {
            // Copying cast.
            let cast_op = Cast { to: case.dtype };
            let result: Value = cast_op.run_simple(&case.input).unwrap();
            assert_eq!(result, case.expected);

            // In-place cast. When the element sizes match this reuses the
            // input buffer. Otherwise `run_in_place` falls back to a copying
            // cast. Both paths must agree with the copying result.
            let result: Value = cast_op
                .run_simple_in_place(case.input.clone(), InputList::new())
                .unwrap();
            assert_eq!(result, case.expected);
        })
    }

    #[test]
    fn test_cast_like() {
        #[derive(Debug)]
        struct Case {
            input: Value,
            other: Value,
            expected: Value,
        }

        // `CastLike` uses the same conversions as the `Cast` operator,
        // so these tests don't check all data type combinations, only that the
        // target type is taken from the second argument.
        let cases = [
            // i32 -> f32 (same element size)
            Case {
                input: Tensor::from([0i32, 1, 2]).into(),
                other: Tensor::from([0f32]).into(),
                expected: Tensor::from([0., 1., 2.]).into(),
            },
            // f32 -> i8 (different element size)
            Case {
                input: Tensor::from([1.0f32, 2.0, 3.0]).into(),
                other: Tensor::from([0i8]).into(),
                expected: Tensor::from([1i8, 2, 3]).into(),
            },
        ];

        cases.test_each(|case| {
            let cast_op = CastLike {};
            let result: Value = cast_op.run_simple((&case.input, &case.other)).unwrap();
            assert_eq!(result, case.expected);

            // In-place cast. The value being cast is passed separately, so
            // `other` is the only entry in the remaining input list.
            let result: Value = cast_op
                .run_simple_in_place(case.input.clone(), InputList::from(&case.other))
                .unwrap();
            assert_eq!(result, case.expected);
        })
    }
}