1use rten_base::byte_cast::{Pod, cast_pod_vec};
2use rten_base::num;
3
4use rten_tensor::Tensor;
5use rten_tensor::prelude::*;
6
7use crate::buffer_pool::BufferPool;
8use crate::ops::{
9 DataType, IntoOpResult, OpError, OpRunContext, Operator, OutputList, Value, ValueView,
10};
11
12fn cast(pool: &BufferPool, input: ValueView, dtype: DataType) -> Result<Value, OpError> {
13 macro_rules! cast_as {
14 ($x:ident) => {
15 Ok($x.to_tensor_in(pool).into())
16 };
17
18 ($x:ident, $dest_ty:ty) => {
19 Ok($x.map_in(pool, |x| *x as $dest_ty).into())
20 };
21 }
22
23 match dtype {
24 DataType::Int32 => match input {
25 ValueView::Int32Tensor(t) => cast_as!(t),
26 ValueView::FloatTensor(t) => cast_as!(t, i32),
27 ValueView::Int8Tensor(t) => cast_as!(t, i32),
28 ValueView::UInt8Tensor(t) => cast_as!(t, i32),
29
30 ValueView::Sequence(_) => Err(OpError::UnsupportedType),
33 },
34 DataType::Float => match input {
35 ValueView::FloatTensor(t) => cast_as!(t),
36 ValueView::Int32Tensor(t) => cast_as!(t, f32),
37 ValueView::Int8Tensor(t) => cast_as!(t, f32),
38 ValueView::UInt8Tensor(t) => cast_as!(t, f32),
39 ValueView::Sequence(_) => Err(OpError::UnsupportedType),
40 },
41 DataType::Int8 => match input {
42 ValueView::Int8Tensor(t) => cast_as!(t),
43 ValueView::FloatTensor(t) => cast_as!(t, i8),
44 ValueView::Int32Tensor(t) => cast_as!(t, i8),
45 ValueView::UInt8Tensor(t) => cast_as!(t, i8),
46 ValueView::Sequence(_) => Err(OpError::UnsupportedType),
47 },
48 DataType::UInt8 => match input {
49 ValueView::UInt8Tensor(t) => cast_as!(t),
50 ValueView::FloatTensor(t) => cast_as!(t, u8),
51 ValueView::Int32Tensor(t) => cast_as!(t, u8),
52 ValueView::Int8Tensor(t) => cast_as!(t, u8),
53 ValueView::Sequence(_) => Err(OpError::UnsupportedType),
54 },
55 }
56}
57
58fn cast_tensor<T, U>(mut data: Tensor<T>) -> Tensor<U>
62where
63 T: Pod + num::Cast<U>,
64 U: Pod<Bytes = T::Bytes>,
65{
66 data.apply(|x| num::Cast::<U>::cast(*x).cast_bytes());
68
69 let shape = data.shape().to_vec();
71 let data = cast_pod_vec::<T, U>(data.into_data()).unwrap();
72 Tensor::from_data(&shape, data)
73}
74
75fn cast_in_place(input: Value, dtype: DataType) -> Result<Value, Value> {
78 match dtype {
79 DataType::Int32 => match input {
80 Value::Int32Tensor(t) => Ok(t.into()),
81 Value::FloatTensor(t) => Ok(cast_tensor::<_, i32>(t).into()),
82 _ => Err(input),
83 },
84 DataType::Float => match input {
85 Value::FloatTensor(t) => Ok(t.into()),
86 Value::Int32Tensor(t) => Ok(cast_tensor::<_, f32>(t).into()),
87 _ => Err(input),
88 },
89 DataType::Int8 => match input {
90 Value::Int8Tensor(t) => Ok(t.into()),
91 Value::UInt8Tensor(t) => Ok(cast_tensor::<_, i8>(t).into()),
92 _ => Err(input),
93 },
94 DataType::UInt8 => match input {
95 Value::UInt8Tensor(t) => Ok(t.into()),
96 Value::Int8Tensor(t) => Ok(cast_tensor::<_, u8>(t).into()),
97 _ => Err(input),
98 },
99 }
100}
101
102#[derive(Debug)]
103pub struct Cast {
104 pub to: DataType,
105}
106
107impl Operator for Cast {
108 fn name(&self) -> &str {
109 "Cast"
110 }
111
112 fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
113 let input = ctx.inputs().require(0)?;
114 cast(ctx.pool(), input, self.to).into_op_result()
115 }
116
117 fn can_run_in_place(&self) -> bool {
118 true
121 }
122
123 fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
124 match cast_in_place(input, self.to) {
125 Ok(output) => Ok(output),
126 Err(input) => {
127 let converted = cast(ctx.pool(), input.as_view(), self.to)?;
128 input.add_to_pool(ctx.pool());
129 Ok(converted)
130 }
131 }
132 }
133}
134
135#[derive(Debug)]
136pub struct CastLike {}
137
138impl Operator for CastLike {
139 fn name(&self) -> &str {
140 "CastLike"
141 }
142
143 fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
144 let to = ctx.inputs().require(1)?.dtype();
145 Cast { to }.run(ctx)
146 }
147
148 fn can_run_in_place(&self) -> bool {
149 true
150 }
151
152 fn run_in_place(&self, input: Value, ctx: &OpRunContext) -> Result<Value, OpError> {
153 let to = ctx.inputs().require(0)?.dtype();
154 Cast { to }.run_in_place(input, ctx)
155 }
156}
157
#[cfg(test)]
mod tests {
    use rten_tensor::Tensor;
    use rten_testing::TestCases;

    use crate::ops::{Cast, CastLike, DataType, InputList, OperatorExt, Value};

    #[test]
    fn test_cast() {
        #[derive(Debug)]
        struct Case {
            input: Value,
            dtype: DataType,
            expected: Value,
        }

        let cases = [
            // i32 -> f32
            Case {
                input: Tensor::from([1, 2, 3]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([1., 2., 3.]).into(),
            },
            // i32 -> i32 (no-op)
            Case {
                input: Tensor::from([1, 2, 3]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([1, 2, 3]).into(),
            },
            // i32 -> i8, at the bounds of the i8 range
            Case {
                input: Tensor::from([i8::MIN as i32, 0, i8::MAX as i32]).into(),
                dtype: DataType::Int8,
                expected: Tensor::from([i8::MIN, 0, i8::MAX]).into(),
            },
            // i32 -> u8, at the bounds of the u8 range
            Case {
                input: Tensor::from([u8::MIN as i32, 0, u8::MAX as i32]).into(),
                dtype: DataType::UInt8,
                expected: Tensor::from([u8::MIN, 0, u8::MAX]).into(),
            },
            // f32 -> i32
            Case {
                input: Tensor::from([1., 2., 3.]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([1, 2, 3]).into(),
            },
            // f32 -> f32 (no-op)
            Case {
                input: Tensor::from([1., 2., 3.]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([1., 2., 3.]).into(),
            },
            // i32 -> f32 at the extremes of the i32 range, which round to the
            // nearest representable f32.
            Case {
                input: Tensor::from([i32::MIN, i32::MAX]).into(),
                dtype: DataType::Float,
                expected: Tensor::from([-2147483600.0, 2147483600.0]).into(),
            },
            // f32 -> i32 for values outside the i32 range, which saturate.
            Case {
                input: Tensor::from([f32::MIN, f32::MAX]).into(),
                dtype: DataType::Int32,
                expected: Tensor::from([i32::MIN, i32::MAX]).into(),
            },
            // i8 -> u8. The types have the same width, so this also exercises
            // the in-place path. Integer casts wrap: negative values map to
            // their two's complement bit pattern.
            Case {
                input: Tensor::from([-1i8, 0, i8::MIN, i8::MAX]).into(),
                dtype: DataType::UInt8,
                expected: Tensor::from([255u8, 0, 128, 127]).into(),
            },
            // u8 -> i8 (same width; values above i8::MAX wrap to negatives).
            Case {
                input: Tensor::from([0u8, 127, 128, 255]).into(),
                dtype: DataType::Int8,
                expected: Tensor::from([0i8, 127, -128, -1]).into(),
            },
            // i8 -> i8 (no-op, including in place)
            Case {
                input: Tensor::from([1i8, -2, 3]).into(),
                dtype: DataType::Int8,
                expected: Tensor::from([1i8, -2, 3]).into(),
            },
        ];

        cases.test_each(|case| {
            let cast_op = Cast { to: case.dtype };

            let result: Value = cast_op.run_simple(&case.input).unwrap();
            assert_eq!(result, case.expected);

            // The in-place path can only reuse the buffer when the input and
            // output element types have the same size.
            if case.input.dtype().size() == case.dtype.size() {
                let result: Value = cast_op
                    .run_simple_in_place(case.input.clone(), InputList::new())
                    .unwrap();
                assert_eq!(result, case.expected);
            }
        })
    }

    #[test]
    fn test_cast_like() {
        #[derive(Debug)]
        struct Case {
            input: Value,
            other: Value,
            expected: Value,
        }

        let cases = [
            // i32 -> f32, taking the target type from `other`
            Case {
                input: Tensor::from([0i32, 1, 2]).into(),
                other: Tensor::from([0f32]).into(),
                expected: Tensor::from([0., 1., 2.]).into(),
            },
            // f32 -> i32, taking the target type from `other`. Fractional
            // parts are truncated toward zero.
            Case {
                input: Tensor::from([1.5f32, -2.5, 0.]).into(),
                other: Tensor::from([0i32]).into(),
                expected: Tensor::from([1, -2, 0]).into(),
            },
        ];

        cases.test_each(|case| {
            let cast_op = CastLike {};
            let result: Value = cast_op.run_simple((&case.input, &case.other)).unwrap();
            assert_eq!(result, case.expected);
        })
    }
}