1use smallvec::SmallVec;
4
5pub use crate::{
6 sym_expr::SymExpr,
7 sym_gen::SymbolGen,
8 sym_tensor::{Constant, SymTensor},
9};
10
/// Errors that can occur during shape inference.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InferShapesError {
    /// The operator was given the wrong number of inputs.
    IncorrectInputCount,

    /// The input shapes are not compatible with each other, for example two
    /// different fixed dimension sizes (neither of which is 1) that would
    /// have to be broadcast together.
    IncompatibleShapes,

    /// An input has the wrong rank for the operator, or an axis is out of
    /// range for the rank of the input it refers to.
    IncorrectRank,

    /// An input has a value that is not valid for the operator.
    InvalidValue,

    /// The number of outputs produced by the operator could not be determined.
    UnknownOutputCount,
}

impl std::fmt::Display for InferShapesError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::IncorrectInputCount => "incorrect number of inputs",
            Self::IncompatibleShapes => "input shapes are incompatible",
            Self::IncorrectRank => "input has incorrect rank or axis is out of range",
            Self::InvalidValue => "input has an invalid value",
            Self::UnknownOutputCount => "number of outputs is unknown",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate shape inference failures as `Box<dyn Error>` etc.
impl std::error::Error for InferShapesError {}
31
32pub trait InferShapes {
34 fn infer_shapes(
41 &self,
42 inputs: &[SymTensor],
43 sym_gen: &mut SymbolGen,
44 ) -> Result<Vec<SymTensor>, InferShapesError>;
45}
46
47pub struct UnaryOp;
54
55impl InferShapes for UnaryOp {
56 fn infer_shapes(
57 &self,
58 inputs: &[SymTensor],
59 _sym_gen: &mut SymbolGen,
60 ) -> Result<Vec<SymTensor>, InferShapesError> {
61 let Some(data) = inputs.first() else {
62 return Err(InferShapesError::IncorrectInputCount);
63 };
64
65 let shape = if let Some(shape) = data.shape() {
66 SymTensor::from_shape(shape.collect())
67 } else {
68 SymTensor::unknown("unknown input shape")
69 };
70
71 Ok([shape].into())
72 }
73}
74
75pub struct BinaryOp;
81
82impl InferShapes for BinaryOp {
83 fn infer_shapes(
84 &self,
85 inputs: &[SymTensor],
86 _sym_gen: &mut SymbolGen,
87 ) -> Result<Vec<SymTensor>, InferShapesError> {
88 let [a, b] = inputs else {
89 return Err(InferShapesError::IncorrectInputCount);
90 };
91
92 let (Some(a_dims), Some(b_dims)) = (a.shape(), b.shape()) else {
93 return Ok([SymTensor::unknown("unknown input shape")].into());
94 };
95
96 let a_pad = b_dims.len().saturating_sub(a_dims.len());
97 let b_pad = a_dims.len().saturating_sub(b_dims.len());
98 let mut out_shape: Vec<SymExpr> = Vec::with_capacity(a_pad + a_dims.len());
99
100 let a_iter = std::iter::repeat_n(SymExpr::Value(1), a_pad).chain(a_dims);
101 let b_iter = std::iter::repeat_n(SymExpr::Value(1), b_pad).chain(b_dims);
102
103 for (a, b) in a_iter.zip(b_iter) {
104 let dim: SymExpr = match (a, b) {
105 (a, b) if a == b => a.clone(),
106
107 (SymExpr::Value(1), b) => b.clone(),
110 (a, SymExpr::Value(1)) => a.clone(),
111
112 (SymExpr::Value(_), SymExpr::Value(_)) => {
115 return Err(InferShapesError::IncompatibleShapes);
116 }
117
118 (SymExpr::Var(_a), SymExpr::Value(b)) => SymExpr::Value(b),
122 (SymExpr::Value(a), SymExpr::Var(_b)) => SymExpr::Value(a),
123
124 (a, b) => a.broadcast(&b),
136 };
137 out_shape.push(dim);
138 }
139
140 Ok([SymTensor::from_shape(out_shape)].into())
141 }
142}
143
144pub struct VariadicOp;
150
151impl InferShapes for VariadicOp {
152 fn infer_shapes(
153 &self,
154 inputs: &[SymTensor],
155 sym_gen: &mut SymbolGen,
156 ) -> Result<Vec<SymTensor>, InferShapesError> {
157 if inputs.is_empty() {
158 return Err(InferShapesError::IncorrectInputCount);
159 }
160
161 let first_shape = inputs[0]
162 .shape()
163 .map(|shape| SymTensor::from_shape(shape.collect()))
164 .unwrap_or_else(|| SymTensor::unknown("unknown input shape"));
165
166 let out_shape: Result<SymTensor, InferShapesError> =
167 inputs
168 .iter()
169 .skip(1)
170 .try_fold(first_shape, |out_shape, in_shape| {
171 let mut shapes =
172 BinaryOp.infer_shapes(&[out_shape, in_shape.clone()], sym_gen)?;
173 Ok(shapes.remove(0))
174 });
175
176 Ok([out_shape?].into())
177 }
178}
179
180#[derive(Clone, Debug, PartialEq)]
182pub struct ReductionOp<'a> {
183 pub axes: Option<&'a [i32]>,
188
189 pub keep_dims: bool,
192}
193
194impl InferShapes for ReductionOp<'_> {
195 fn infer_shapes(
196 &self,
197 inputs: &[SymTensor],
198 _sym_gen: &mut SymbolGen,
199 ) -> Result<Vec<SymTensor>, InferShapesError> {
200 match inputs.len() {
201 1 | 2 => {}
202 _ => {
203 return Err(InferShapesError::IncorrectInputCount);
204 }
205 }
206
207 let data = &inputs[0];
208
209 let Some(data_dims) = data.shape() else {
210 return Ok([SymTensor::unknown("unknown input shape")].into());
211 };
212
213 let ndim = data_dims.len();
214 let mut axes: SmallVec<[usize; 4]> =
215 if let Some(Constant::Vector(axes)) = inputs.get(1).and_then(|x| x.to_constant()) {
216 resolve_axes(ndim, axes.iter()).map_err(|_| InferShapesError::IncorrectRank)?
217 } else if let Some(axes) = self.axes {
218 resolve_axes(ndim, axes.iter()).map_err(|_| InferShapesError::IncorrectRank)?
219 } else {
220 (0..ndim).collect()
221 };
222 axes.sort();
223 axes.dedup();
224
225 let out_ndim = if self.keep_dims {
226 ndim
227 } else {
228 ndim - axes.len()
229 };
230 let mut out_shape = Vec::with_capacity(out_ndim);
231
232 for (i, dim) in data_dims.enumerate() {
233 if !axes.contains(&i) {
234 out_shape.push(dim.clone());
235 continue;
236 } else if self.keep_dims {
237 out_shape.push(SymExpr::Value(1));
238 }
239 }
240
241 Ok([SymTensor::from_shape(out_shape)].into())
242 }
243}
244
/// Convert a possibly-negative `index` into an offset in `0..len`.
///
/// Non-negative indices are used as-is. Negative indices count back from the
/// end, so `-1` is `len - 1`. Returns `None` if `index` is outside `-len..len`.
///
/// The arithmetic is done in `usize`, so `len` is never truncated and
/// `i32::MIN` cannot overflow.
fn resolve_index(len: usize, index: i32) -> Option<usize> {
    match usize::try_from(index) {
        Ok(index) => (index < len).then_some(index),
        Err(_) => {
            // `index` is negative: step back `|index|` positions from `len`,
            // failing if that would go before the start.
            let from_end = usize::try_from(index.unsigned_abs()).ok()?;
            len.checked_sub(from_end)
        }
    }
}
259
260pub(crate) fn resolve_axis(ndim: usize, axis: i32) -> Result<usize, InferShapesError> {
265 resolve_index(ndim, axis).ok_or(InferShapesError::IncorrectRank)
266}
267
268fn resolve_axes<'a, I: ExactSizeIterator<Item = &'a i32>>(
273 ndim: usize,
274 axes: I,
275) -> Result<SmallVec<[usize; 4]>, InferShapesError> {
276 let mut resolved_axes = SmallVec::with_capacity(axes.len());
277 for axis in axes {
278 let resolved = resolve_axis(ndim, *axis)?;
279 resolved_axes.push(resolved);
280 }
281 Ok(resolved_axes)
282}
283
// Unit tests for the shape inference implementations in this module.
#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{
        BinaryOp, InferShapes, InferShapesError, ReductionOp, SymExpr, SymTensor, SymbolGen,
        UnaryOp, VariadicOp,
    };
    use crate::sym_tensor::{sym_elems, sym_shape};

    // `UnaryOp` passes the input shape through unchanged and rejects an empty
    // input list.
    #[test]
    fn test_unary_op_infer() {
        let input = sym_shape!("batch", 16, "seq", 24);
        let mut sym_gen = SymbolGen::new();
        let shape = UnaryOp
            .infer_shapes(&[input.clone()], &mut sym_gen)
            .unwrap();
        assert_eq!(shape.len(), 1);
        assert_eq!(shape[0], input);

        let err = UnaryOp.infer_shapes(&[], &mut sym_gen).err().unwrap();
        assert_eq!(err, InferShapesError::IncorrectInputCount);
    }

    // Broadcasting of two input shapes by `BinaryOp`.
    #[test]
    fn test_binary_op() {
        #[derive(Debug)]
        struct Case {
            lhs: SymTensor,
            rhs: SymTensor,
            expected: SymTensor,
        }

        let cases = [
            // Identical symbolic dimensions.
            Case {
                lhs: sym_shape!("batch"),
                rhs: sym_shape!("batch"),
                expected: sym_shape!("batch"),
            },
            // Identical fixed dimensions.
            Case {
                lhs: sym_shape!(2, 3),
                rhs: sym_shape!(2, 3),
                expected: sym_shape!(2, 3),
            },
            // Size-1 dimensions on either side take the other side's size.
            Case {
                lhs: sym_shape!(1, 5),
                rhs: sym_shape!(4, 1),
                expected: sym_shape!(4, 5),
            },
            // All-ones shapes.
            Case {
                lhs: sym_shape!(1, 1),
                rhs: sym_shape!(1, 1),
                expected: sym_shape!(1, 1),
            },
            // Size-1 dimensions broadcast against symbols.
            Case {
                lhs: sym_shape!(1, "bar"),
                rhs: sym_shape!("foo", 1),
                expected: sym_shape!("foo", "bar"),
            },
            // Different symbols produce a symbolic broadcast expression.
            Case {
                lhs: sym_shape!("foo"),
                rhs: sym_shape!("bar"),
                expected: sym_shape!(SymExpr::from("foo").broadcast(&SymExpr::from("bar"))),
            },
        ];

        cases.test_each(|case| {
            let mut sym_gen = SymbolGen::new();
            let shape = BinaryOp
                .infer_shapes(&[case.lhs.clone(), case.rhs.clone()], &mut sym_gen)
                .unwrap();
            assert_eq!(shape.len(), 1);
            assert_eq!(shape[0], case.expected.clone());
        });
    }

    // Error cases for `BinaryOp`.
    #[test]
    fn test_binary_op_invalid() {
        #[derive(Clone, Debug)]
        struct Case {
            inputs: Vec<Vec<SymExpr>>,
            expected: InferShapesError,
        }

        let cases = [
            // Only one input.
            Case {
                inputs: [sym_elems!(5)].into(),
                expected: InferShapesError::IncorrectInputCount,
            },
            // Two different fixed sizes, neither of which is 1.
            Case {
                inputs: [sym_elems!(5), sym_elems!(3)].into(),
                expected: InferShapesError::IncompatibleShapes,
            },
        ];

        cases.test_each_clone(|case| {
            let mut sym_gen = SymbolGen::new();
            let inputs: Vec<_> = case.inputs.into_iter().map(SymTensor::from_shape).collect();
            let err = BinaryOp.infer_shapes(&inputs, &mut sym_gen).err().unwrap();
            assert_eq!(err, case.expected);
        });
    }

    // `VariadicOp` with one input and with several broadcastable inputs.
    #[test]
    fn test_variadic_op() {
        let mut sym_gen = SymbolGen::new();
        let a = sym_shape!("batch", 4, 1, 1);
        let b = sym_shape!("batch", 1, 8, 1);
        let c = sym_shape!("batch", 1, 8, 16);

        // A single input is passed through unchanged.
        let result = VariadicOp.infer_shapes(&[a.clone()], &mut sym_gen).unwrap();
        assert_eq!(result[0], sym_shape!("batch", 4, 1, 1));

        // Multiple inputs are broadcast together.
        let result = VariadicOp
            .infer_shapes(&[a.clone(), b, c], &mut sym_gen)
            .unwrap();
        assert_eq!(result[0], sym_shape!("batch", 4, 8, 16));
    }

    // `ReductionOp` with axes from an input, from the attribute, with
    // `keep_dims`, and with the default of reducing all axes.
    #[test]
    fn test_reduction_op() {
        #[derive(Clone, Debug)]
        struct Case<'a> {
            inputs: Vec<SymTensor>,
            op: ReductionOp<'a>,
            expected: Vec<SymExpr>,
        }

        let axes = vec![SymExpr::Value(1i32)];

        let default_op = ReductionOp {
            axes: None,
            keep_dims: false,
        };

        let cases = [
            // Axes supplied as a constant second input.
            Case {
                inputs: [
                    SymTensor::from_shape(sym_elems!("batch", 4, 5)),
                    SymTensor::from_vec(axes.clone()),
                ]
                .into(),
                op: default_op.clone(),
                expected: sym_elems!("batch", 5),
            },
            // Axes supplied via the `axes` field.
            Case {
                inputs: [SymTensor::from_shape(sym_elems!("batch", 4, 5))].into(),
                op: ReductionOp {
                    axes: Some(&[1i32]),
                    ..default_op
                },
                expected: sym_elems!("batch", 5),
            },
            // Reduced dimensions kept as size 1.
            Case {
                inputs: [
                    SymTensor::from_shape(sym_elems!("batch", 4, 5)),
                    SymTensor::from_vec(axes.clone()),
                ]
                .into(),
                op: ReductionOp {
                    keep_dims: true,
                    ..default_op
                },
                expected: sym_elems!("batch", 1, 5),
            },
            // No axes given: all axes are reduced.
            Case {
                inputs: [SymTensor::from_shape(sym_elems!(3, 4, 5))].into(),
                op: default_op.clone(),
                expected: sym_elems!(),
            },
        ];

        cases.test_each(|case| {
            let mut sym_gen = SymbolGen::new();
            let shapes = case.op.infer_shapes(&case.inputs, &mut sym_gen).unwrap();
            assert_eq!(shapes.len(), 1);
            assert_eq!(shapes[0], SymTensor::from_shape(case.expected.clone()));
        });
    }
}