1use std::any::Any;
16use std::borrow::Cow;
17use std::convert::Infallible;
18use std::error::Error;
19use std::fmt;
20use std::fmt::{Debug, Display};
21
22use smallvec::SmallVec;
23
24use rten_gemm::PackedBMatrix;
25use rten_tensor::errors::DimensionError;
26use rten_tensor::{MutLayout, Storage, TensorBase};
27
28use crate::buffer_pool::BufferPool;
29use crate::graph::{CaptureEnv, Graph, RunError, RunOptions};
30use crate::timing::Profiler;
31use crate::value::{CastError, DataType, DataTypeOf, Value, ValueOrView, ValueView};
32use crate::weight_cache::WeightCache;
33
34mod attention;
35mod binary_elementwise;
36mod concat;
37mod control_flow;
38mod conv;
39mod conv_transpose;
40mod convert;
41mod einsum;
42mod gather;
43mod generate;
44mod grid_sample;
45mod identity;
46mod layout;
47mod matmul;
48mod non_max_suppression;
49mod norm;
50mod pad;
51mod pooling;
52mod quantize;
53
54#[cfg(feature = "fft")]
55mod fft;
56
57#[cfg(feature = "random")]
58mod random;
59
60mod reduce;
61mod resize;
62mod rnn;
63mod sequence;
64mod slice;
65mod split;
66mod trilu;
67mod unary_elementwise;
68mod variadic_elementwise;
69
70pub(crate) mod transform_inputs;
72
73pub use attention::AddSoftmax;
74pub use binary_elementwise::{
75 Add, And, Div, DivMode, Equal, Greater, GreaterOrEqual, Less, LessOrEqual, Mod, Mul, Or, Pow,
76 Sub, Where, Xor, add, and, div, equal, greater, greater_or_equal, less, less_or_equal, mod_op,
77 mul, or, pow, sub, where_op, xor,
78};
79pub use concat::{Concat, Tile, concat, tile};
80pub use control_flow::{If, Loop};
81pub use conv::{Conv, ConvInteger, conv, conv_integer};
82pub use conv_transpose::{ConvTranspose, conv_transpose};
83pub use convert::{Cast, CastLike};
84pub use einsum::{Einsum, einsum};
85pub use gather::{
86 Gather, GatherElements, GatherND, ScatterElements, ScatterND, ScatterReduction, gather,
87 gather_elements, gather_nd, scatter_elements, scatter_nd,
88};
89pub use generate::{ConstantOfShape, EyeLike, OneHot, Range, constant_of_shape, onehot, range};
90pub use grid_sample::GridSample;
91pub use identity::Identity;
92pub use layout::{
93 DepthToSpace, DepthToSpaceMode, Expand, Flatten, Reshape, Shape, Size, Squeeze, Transpose,
94 Unsqueeze, depth_to_space, expand, flatten, reshape, squeeze,
95};
96pub use matmul::{FusedMatMul, Gemm, MatMul, MatMulInteger, MatMulIntegerToFloat, gemm_op, matmul};
97pub use non_max_suppression::{BoxOrder, NonMaxSuppression, non_max_suppression};
98pub use norm::{
99 BatchNormalization, InstanceNormalization, LayerNormalization, LogSoftmax, RmsNormalization,
100 Softmax, batch_norm, instance_normalization, layer_normalization, log_softmax,
101 rms_normalization, softmax,
102};
103pub use pad::{Pad, PadMode, pad};
104pub use pooling::{
105 AveragePool, GlobalAveragePool, MaxPool, average_pool, global_average_pool, max_pool,
106};
107pub use quantize::{
108 DequantizeLinear, DynamicQuantizeLinear, QuantizeLinear, dequantize_linear,
109 dynamic_quantize_linear, quantize_linear,
110};
111
112#[cfg(feature = "fft")]
113pub use fft::{STFT, stft};
114
115#[cfg(feature = "random")]
116pub use random::{Dropout, RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike};
117
118pub use reduce::{
119 ArgMax, ArgMin, CumSum, NonZero, ReduceL2, ReduceMax, ReduceMean, ReduceMin, ReduceProd,
120 ReduceSum, ReduceSumSquare, TopK, arg_max, arg_min, cum_sum, nonzero, reduce_l2, reduce_max,
121 reduce_mean, reduce_min, reduce_prod, reduce_sum, reduce_sum_square, topk,
122};
123pub use resize::{
124 CoordTransformMode, NearestMode, Resize, ResizeMode, ResizeTarget, resize, resize_image,
125};
126pub use rnn::{Direction, GRU, LSTM, gru, lstm};
127pub use sequence::{
128 ConcatFromSequence, SequenceAt, SequenceConstruct, SequenceEmpty, SequenceErase,
129 SequenceInsert, SequenceLength, SplitToSequence,
130};
131pub use slice::{Slice, slice};
132pub use split::{Split, split};
133pub use trilu::{Trilu, trilu};
134pub use unary_elementwise::{
135 Abs, Acos, Asin, Atan, Ceil, Clip, Cos, Elu, Erf, Exp, Floor, Gelu, HardSigmoid, HardSwish,
136 IsInf, IsNaN, LeakyRelu, Log, Neg, Not, PRelu, Reciprocal, Relu, Round, Sigmoid, Sign, Silu,
137 Sin, Softplus, Sqrt, Swish, Tan, Tanh,
138};
139pub use variadic_elementwise::{Max, Mean, Min, Sum, max, mean, min, sum};
140
141mod operators;
142pub use operators::{FloatOperators, Operators};
143
/// Specifies the padding applied to the spatial dimensions of an input
/// before an operation such as a convolution or pooling.
#[derive(Clone, Debug)]
pub enum Padding {
    /// Automatically pad so the output has the same spatial size as the
    /// input ("same" padding).
    Same,

    /// Explicit padding amounts, ordered as the start-padding for each
    /// spatial dimension followed by the end-padding for each spatial
    /// dimension (e.g. `[top, left, bottom, right]` for 2D; see
    /// `expand_1d_to_2d`).
    Fixed(SmallVec<[usize; 4]>),
}
159
160impl Padding {
161 pub fn zero<const N: usize>() -> Padding {
163 Padding::Fixed(SmallVec::from_elem(0, N * 2))
164 }
165
166 pub fn expand_1d_to_2d(&self) -> Result<Padding, OpError> {
168 match self {
169 Padding::Same => Ok(Padding::Same),
170 Padding::Fixed(pads) => match pads.as_slice() {
171 &[pad_start, pad_end] => Ok([0, pad_start, 0, pad_end].into()),
172 _ => Err(OpError::InvalidValue("expected 2 pad values")),
173 },
174 }
175 }
176}
177
178impl<S: AsRef<[usize]>> From<S> for Padding {
180 fn from(val: S) -> Padding {
181 Padding::Fixed(val.as_ref().into())
182 }
183}
184
/// A weight input that has been pre-packed into the layout required by a
/// compute kernel.
pub enum PrepackedInput {
    /// Pre-packed f32 right-hand ("B") matrix for a matrix multiplication.
    FloatBMatrix(PackedBMatrix<f32>),

    /// Pre-packed i8 right-hand ("B") matrix for a quantized matrix
    /// multiplication.
    Int8BMatrix(PackedBMatrix<i8>),
}
194
195impl PrepackedInput {
196 fn dtype(&self) -> DataType {
197 match self {
198 Self::FloatBMatrix(_) => DataType::Float,
199 Self::Int8BMatrix(_) => DataType::Int8,
200 }
201 }
202}
203
/// Implement conversions between [`PrepackedInput`] and a concrete
/// `PackedBMatrix<T>` for a given element type / enum variant pair:
///
/// - `From<PackedBMatrix<$type>>` wraps a packed matrix in the enum.
/// - `TryFrom<&PrepackedInput>` extracts a reference to the packed matrix,
///   failing with `CastError::WrongType` if the variant does not match.
macro_rules! impl_prepacked_input_conversions {
    ($type:ty, $variant:ident) => {
        impl From<PackedBMatrix<$type>> for PrepackedInput {
            fn from(value: PackedBMatrix<$type>) -> Self {
                PrepackedInput::$variant(value)
            }
        }

        impl<'a> TryFrom<&'a PrepackedInput> for &'a PackedBMatrix<$type> {
            type Error = CastError;

            fn try_from(ppi: &'a PrepackedInput) -> Result<Self, Self::Error> {
                match ppi {
                    PrepackedInput::$variant(packed) => Ok(packed),
                    _ => Err(CastError::WrongType {
                        actual: ppi.dtype(),
                        expected: <$type as DataTypeOf>::dtype_of(),
                    }),
                }
            }
        }
    };
}
impl_prepacked_input_conversions!(f32, FloatBMatrix);
impl_prepacked_input_conversions!(i8, Int8BMatrix);
229
/// Former name of [`ValueOrView`].
#[deprecated = "renamed to `ValueOrView`"]
pub type InputOrOutput<'a> = ValueOrView<'a>;

/// Former name of [`ValueView`].
#[deprecated = "renamed to `ValueView`"]
pub type Input<'a> = ValueView<'a>;

/// Former name of [`Value`].
#[deprecated = "renamed to `Value`"]
pub type Output = Value;
238
/// Trait for converting the result of an operator implementation into the
/// `Result<OutputList, OpError>` type returned by [`Operator::run`].
pub trait IntoOpResult {
    /// Convert `self` into an operator result.
    fn into_op_result(self) -> Result<OutputList, OpError>;
}
244
245impl IntoOpResult for Result<Value, OpError> {
246 fn into_op_result(self) -> Result<OutputList, OpError> {
247 self.map(|out| [out].into())
248 }
249}
250
251impl IntoOpResult for Value {
252 fn into_op_result(self) -> Result<OutputList, OpError> {
253 Ok([self].into())
254 }
255}
256
257impl<S: Storage, L: MutLayout> IntoOpResult for TensorBase<S, L>
258where
259 Value: From<TensorBase<S, L>>,
260{
261 fn into_op_result(self) -> Result<OutputList, OpError> {
262 let output: Value = self.into();
263 Ok([output].into())
264 }
265}
266
267impl<S: Storage, L: MutLayout> IntoOpResult for Result<TensorBase<S, L>, OpError>
268where
269 Value: From<TensorBase<S, L>>,
270{
271 fn into_op_result(self) -> Result<OutputList, OpError> {
272 self.map(|tensor| [tensor.into()].into())
273 }
274}
275
276impl<T> IntoOpResult for Result<Vec<T>, OpError>
277where
278 Value: From<T>,
279{
280 fn into_op_result(self) -> Result<OutputList, OpError> {
281 self.map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
282 }
283}
284
/// An error which occurred while running an operator.
#[derive(Eq, PartialEq, Debug)]
pub enum OpError {
    /// Converting a value to the type expected by the operator failed.
    CastFailed(CastError),

    /// Converting the input at a given index to the expected type failed.
    InputCastFailed { index: usize, error: CastError },

    /// The data type of an input is not supported by this operator.
    UnsupportedType,

    /// Inputs have shapes which are incompatible with each other or with
    /// the operator's attributes.
    IncompatibleInputShapes(&'static str),

    /// One or more required inputs were missing.
    MissingInputs,

    /// An input or attribute has an invalid value.
    InvalidValue(&'static str),

    /// An input or attribute has a value which is valid but not currently
    /// supported.
    UnsupportedValue(&'static str),
}
310
311impl OpError {
312 fn with_input_index(self, index: usize) -> OpError {
314 match self {
315 Self::CastFailed(error) => OpError::InputCastFailed { index, error },
316 Self::InputCastFailed { error, .. } => OpError::InputCastFailed { index, error },
317 other => other,
318 }
319 }
320}
321
322impl From<DimensionError> for OpError {
323 fn from(val: DimensionError) -> OpError {
324 OpError::CastFailed(val.into())
325 }
326}
327
328impl From<CastError> for OpError {
329 fn from(val: CastError) -> OpError {
330 OpError::CastFailed(val)
331 }
332}
333
334impl From<Infallible> for OpError {
335 fn from(x: Infallible) -> OpError {
336 match x {}
337 }
338}
339
340impl Display for OpError {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 match self {
343 OpError::CastFailed(err) => write!(f, "{}", err),
344 OpError::InputCastFailed { index, error } => {
345 write!(f, "conversion error for input {}: {}", index, error)
346 }
347 OpError::IncompatibleInputShapes(details) => {
348 write!(f, "incompatible input shapes: {}", details)
349 }
350 OpError::MissingInputs => write!(f, "required inputs were missing"),
351 OpError::InvalidValue(details) => {
352 write!(f, "input or attribute has invalid value: {}", details)
353 }
354 OpError::UnsupportedValue(details) => {
355 write!(f, "unsupported input or attribute value: {}", details)
356 }
357 OpError::UnsupportedType => {
358 write!(f, "unsupported input type")
359 }
360 }
361 }
362}
363
364impl Error for OpError {}
365
/// Extract a static-rank (`NdTensor`) view of a tensor, or return an
/// `OpError::InvalidValue` if the tensor does not have `$ndim` dimensions.
///
/// Forms:
/// - `static_dims!(tensor, ndim, "dim names")` — error message includes the
///   expected dimension names.
/// - `static_dims!(tensor, ndim)` — error message without dimension names.
/// - `static_dims!(tensor?, ndim)` — `tensor` is an `Option`; yields
///   `None` if absent, otherwise `Some(result)`.
macro_rules! static_dims {
    ($tensor:ident, $ndim:literal, $dim_names:literal) => {{
        use rten_tensor::prelude::*;

        if $tensor.ndim() != $ndim {
            Err(OpError::InvalidValue(concat!(
                stringify!($tensor),
                " must have ",
                stringify!($ndim),
                " dims (",
                $dim_names,
                ")"
            )))
        } else {
            Ok($tensor.nd_view::<$ndim>())
        }
    }};

    ($tensor:ident, $ndim:literal) => {{
        use rten_tensor::prelude::*;

        if $tensor.ndim() != $ndim {
            Err(OpError::InvalidValue(concat!(
                stringify!($tensor),
                " must have ",
                stringify!($ndim),
                " dims"
            )))
        } else {
            Ok($tensor.nd_view::<$ndim>())
        }
    }};

    ($tensor:ident?, $ndim: expr) => {
        if let Some($tensor) = $tensor.as_ref() {
            Some(static_dims!($tensor, $ndim))
        } else {
            None
        }
    };
}

pub(crate) use static_dims;
415
/// Context passed to [`Operator::run`] for a single operator execution,
/// carrying the inputs and associated resources.
pub struct OpRunContext<'a, 'i> {
    // Pool from which output buffers should be allocated.
    pool: &'a BufferPool,
    // Input values for this execution.
    inputs: &'a InputList<'i>,
    // Expected number of outputs, if known.
    n_outputs: Option<u32>,
    // Name of the graph node being run, if available.
    name: Option<&'a str>,
}
424
impl<'a, 'i> OpRunContext<'a, 'i> {
    /// Create a context with the given buffer pool and inputs. The output
    /// count and node name are initially unset.
    pub fn new(pool: &'a BufferPool, inputs: &'a InputList<'i>) -> Self {
        OpRunContext {
            pool,
            inputs,
            n_outputs: None,
            name: None,
        }
    }

    /// Create a copy of this context with a different input list, keeping
    /// the other fields.
    pub fn with_new_inputs<'b, 'il>(&self, inputs: &'b InputList<'il>) -> OpRunContext<'b, 'il>
    where
        'a: 'b,
    {
        OpRunContext { inputs, ..*self }
    }

    /// The pool which should be used to allocate output buffers.
    pub fn pool(&self) -> &BufferPool {
        self.pool
    }

    /// The inputs for this operator execution.
    pub fn inputs(&self) -> &InputList<'i> {
        self.inputs
    }

    /// Set the expected number of outputs.
    pub fn set_num_outputs(&mut self, n: u32) {
        self.n_outputs = Some(n);
    }

    /// The expected number of outputs, or `None` if not set.
    pub fn num_outputs(&self) -> Option<u32> {
        self.n_outputs
    }

    /// Set the name of the graph node being run.
    pub fn set_name(&mut self, name: Option<&'a str>) {
        self.name = name;
    }

    /// The name of the graph node being run, if set.
    pub fn name(&self) -> Option<&str> {
        self.name
    }
}
483
/// The list of output values produced by an operator. The inline capacity
/// of one reflects that most operators produce a single output.
pub type OutputList = SmallVec<[Value; 1]>;
489
/// An operator which computes output values from a set of inputs.
pub trait Operator: Any + Debug {
    /// The name of this operator.
    fn name(&self) -> &str;

    /// Execute the operator with the inputs and resources from `ctx` and
    /// return the outputs.
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError>;

    /// Return true if this operator supports in-place execution via
    /// [`run_in_place`](Operator::run_in_place). Defaults to false.
    fn can_run_in_place(&self) -> bool {
        false
    }

    /// Return true if the order of this operator's inputs does not affect
    /// the result. Defaults to false.
    fn is_commutative(&self) -> bool {
        false
    }

    /// Return true if this operator always produces the same outputs for
    /// the same inputs. Defaults to true; operators with random output
    /// should return false.
    fn is_deterministic(&self) -> bool {
        true
    }

    /// Execute the operator, consuming `input` (typically the first input)
    /// so its buffer can be reused for the output.
    ///
    /// The default implementation returns an error. Operators which return
    /// true from [`can_run_in_place`](Operator::can_run_in_place) must
    /// override this.
    fn run_in_place(
        &self,
        #[allow(unused)] input: Value,
        #[allow(unused)] ctx: &OpRunContext,
    ) -> Result<Value, OpError> {
        Err(OpError::InvalidValue("In-place execution not supported"))
    }

    /// Return the indices of inputs which can be pre-packed via
    /// [`prepack`](Operator::prepack). Defaults to none.
    fn prepack_inputs(&self) -> SmallVec<[usize; 1]> {
        SmallVec::new()
    }

    /// Pre-pack the input at `index` into a kernel-friendly layout, or
    /// return `None` if this input cannot be pre-packed.
    fn prepack(
        &self,
        #[allow(unused)] index: usize,
        #[allow(unused)] input: ValueView,
    ) -> Option<PrepackedInput> {
        None
    }

    /// Return this operator as a [`SubgraphOperator`] if it executes
    /// subgraphs, or `None` (the default) otherwise.
    fn as_subgraph_op(&self) -> Option<&dyn SubgraphOperator> {
        None
    }
}
586
impl dyn Operator {
    /// Downcast this operator to a concrete type, returning `None` if it is
    /// not of type `T`.
    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
        (self as &dyn Any).downcast_ref()
    }
}
593
/// An operator which executes nested subgraphs (e.g. control flow).
pub trait SubgraphOperator: Operator {
    /// Return the subgraphs owned by this operator. Defaults to none.
    fn subgraphs(&self) -> SmallVec<[&Graph; 2]> {
        SmallVec::new()
    }

    /// Run the operator's subgraph(s).
    ///
    /// `captures` provides values captured from the enclosing graph;
    /// `weight_cache`, `profiler` and `run_opts` are forwarded to subgraph
    /// execution when available.
    fn run_subgraph<'a>(
        &'a self,
        ctx: &OpRunContext,
        #[allow(unused)] captures: CaptureEnv,
        #[allow(unused)] weight_cache: Option<&[WeightCache]>,
        #[allow(unused)] profiler: Option<&mut Profiler<'a>>,
        #[allow(unused)] run_opts: Option<RunOptions>,
    ) -> Result<OutputList, RunError>;
}
614
615pub trait OperatorExt: Operator {
617 fn run_simple<'a, I: Into<InputList<'a>>, O: TryFrom<Value>>(
623 &self,
624 inputs: I,
625 ) -> Result<O, OpError>
626 where
627 OpError: From<<O as TryFrom<Value>>::Error>,
628 {
629 let pool = BufferPool::new();
630 let inputs = inputs.into();
631 let ctx = OpRunContext::new(&pool, &inputs);
632 let mut outputs = self.run(&ctx)?;
633 Ok(outputs.remove(0).try_into()?)
634 }
635
636 fn run_simple_in_place<'a, M: Into<Value>, I: Into<InputList<'a>>, O: TryFrom<Value>>(
638 &self,
639 mut_input: M,
640 inputs: I,
641 ) -> Result<O, OpError>
642 where
643 OpError: From<<O as TryFrom<Value>>::Error>,
644 {
645 let pool = BufferPool::new();
646 let inputs = inputs.into();
647 let ctx = OpRunContext::new(&pool, &inputs);
648 let output = self.run_in_place(mut_input.into(), &ctx)?;
649 let typed_output = output.try_into()?;
650 Ok(typed_output)
651 }
652}
653
654impl<O: ?Sized + Operator> OperatorExt for O {}
655
/// The list of inputs passed to an operator.
#[derive(Clone)]
pub struct InputList<'a> {
    // Input values. `None` entries represent omitted optional inputs.
    inputs: Cow<'a, [Option<ValueView<'a>>]>,

    // Callback to look up a pre-packed copy of an input by index.
    get_prepacked: Option<&'a dyn Fn(usize) -> Option<&'a PrepackedInput>>,

    // If true, index 0 in this list corresponds to the operator's input 1;
    // this offset is applied to indices reported in errors.
    first_input_omitted: bool,
}
678
679impl<'a> InputList<'a> {
680 pub fn new() -> InputList<'a> {
682 InputList {
683 inputs: Cow::Owned(vec![]),
684 get_prepacked: None,
685 first_input_omitted: false,
686 }
687 }
688
689 pub fn with_first_input_omitted(mut self, offset: bool) -> Self {
695 self.first_input_omitted = offset;
696 self
697 }
698
699 pub fn len(&self) -> usize {
700 self.inputs.len()
701 }
702
703 pub fn is_empty(&self) -> bool {
704 self.inputs.is_empty()
705 }
706
707 pub fn push<I: Into<ValueView<'a>>>(&mut self, inp: I) {
711 self.inputs.to_mut().push(Some(inp.into()))
712 }
713
714 pub fn push_optional<I: Into<ValueView<'a>>>(&mut self, inp: Option<I>) {
718 self.inputs.to_mut().push(inp.map(|inp| inp.into()))
719 }
720
721 pub fn from(inputs: &[ValueView<'a>]) -> InputList<'a> {
726 InputList {
727 inputs: inputs.iter().cloned().map(Some).collect(),
728 get_prepacked: None,
729 first_input_omitted: false,
730 }
731 }
732
733 pub fn from_optional(inputs: &'a [Option<ValueView<'a>>]) -> InputList<'a> {
737 InputList {
738 inputs: Cow::Borrowed(inputs),
739 get_prepacked: None,
740 first_input_omitted: false,
741 }
742 }
743
744 pub fn with_prepacked(
747 mut self,
748 lookup: &'a dyn Fn(usize) -> Option<&'a PrepackedInput>,
749 ) -> Self {
750 self.get_prepacked = Some(lookup);
751 self
752 }
753
754 pub fn get(&self, index: usize) -> Option<ValueView<'a>> {
756 self.inputs.get(index).cloned().flatten()
757 }
758
759 pub fn get_prepacked(&self, index: usize) -> Option<&'a PrepackedInput> {
761 self.get_prepacked.and_then(|gp| gp(index))
762 }
763
764 pub fn get_mut(&mut self, index: usize) -> Option<&mut ValueView<'a>> {
768 self.inputs.to_mut().get_mut(index)?.as_mut()
769 }
770
771 pub fn get_as<T>(&self, index: usize) -> Result<Option<T>, OpError>
773 where
774 T: TryFrom<ValueView<'a>, Error = CastError>,
775 {
776 self.get(index)
777 .map(|input| {
778 input.try_into().map_err(|error| OpError::InputCastFailed {
779 index: self.to_real_index(index),
780 error,
781 })
782 })
783 .transpose()
784 }
785
786 pub fn require(&self, index: usize) -> Result<ValueView<'a>, OpError> {
788 self.get(index).ok_or(OpError::MissingInputs)
789 }
790
791 pub fn require_as<T>(&self, index: usize) -> Result<T, OpError>
793 where
794 T: TryFrom<ValueView<'a>, Error = CastError>,
795 {
796 self.require(index).and_then(|input| {
797 input.try_into().map_err(|error| OpError::InputCastFailed {
798 index: self.to_real_index(index),
799 error,
800 })
801 })
802 }
803
804 pub fn iter<'b>(&'b self) -> impl Iterator<Item = Option<ValueView<'a>>> + 'b {
808 self.inputs.iter().cloned()
809 }
810
811 fn to_real_index(&self, index: usize) -> usize {
814 if self.first_input_omitted {
815 index + 1
816 } else {
817 index
818 }
819 }
820}
821
impl Default for InputList<'_> {
    /// Create an empty input list, equivalent to [`InputList::new`].
    fn default() -> Self {
        Self::new()
    }
}
827
828impl<'a, I: Into<ValueView<'a>>> From<I> for InputList<'a> {
829 fn from(val: I) -> InputList<'a> {
830 InputList::from(&[val.into()])
831 }
832}
833
834impl<'a> From<()> for InputList<'a> {
835 fn from(_: ()) -> InputList<'a> {
836 Self::default()
837 }
838}
839
840impl<'a, I1: Into<ValueView<'a>>> From<(I1,)> for InputList<'a> {
841 fn from((a,): (I1,)) -> InputList<'a> {
842 InputList::from(&[a.into()])
843 }
844}
845
846impl<'a, I1: Into<ValueView<'a>>, I2: Into<ValueView<'a>>> From<(I1, I2)> for InputList<'a> {
847 fn from((a, b): (I1, I2)) -> InputList<'a> {
848 InputList::from(&[a.into(), b.into()])
849 }
850}
851
852impl<'a, I1: Into<ValueView<'a>>, I2: Into<ValueView<'a>>, I3: Into<ValueView<'a>>>
853 From<(I1, I2, I3)> for InputList<'a>
854{
855 fn from((a, b, c): (I1, I2, I3)) -> InputList<'a> {
856 InputList::from(&[a.into(), b.into(), c.into()])
857 }
858}
859
860impl<'a> Extend<ValueView<'a>> for InputList<'a> {
861 fn extend<T>(&mut self, iter: T)
862 where
863 T: IntoIterator<Item = ValueView<'a>>,
864 {
865 for item in iter {
866 self.push(item);
867 }
868 }
869}
870
871impl<'a> Extend<Option<ValueView<'a>>> for InputList<'a> {
872 fn extend<T>(&mut self, iter: T)
873 where
874 T: IntoIterator<Item = Option<ValueView<'a>>>,
875 {
876 for item in iter {
877 self.push_optional(item);
878 }
879 }
880}
881
882impl<'a, A> FromIterator<A> for InputList<'a>
883where
884 InputList<'a>: Extend<A>,
885{
886 fn from_iter<T>(iter: T) -> Self
887 where
888 T: IntoIterator<Item = A>,
889 {
890 let mut list = InputList::new();
891 list.extend(iter);
892 list
893 }
894}
895
/// Convert a signed index, where negative values count back from the end of
/// a dimension of size `len`, into an unsigned in-bounds index.
///
/// Returns `None` if the index is out of bounds (i.e. not in
/// `[-len, len - 1]`).
fn resolve_index(len: usize, index: isize) -> Option<usize> {
    let signed_len = len as isize;
    let shifted = if index < 0 { index + signed_len } else { index };
    (0..signed_len)
        .contains(&shifted)
        .then_some(shifted as usize)
}
910
/// Resolve a possibly-negative axis to a dimension index in `[0, ndim)`,
/// returning an error if it is out of range.
fn resolve_axis(ndim: usize, axis: isize) -> Result<usize, OpError> {
    resolve_index(ndim, axis).ok_or(OpError::InvalidValue("Axis is invalid"))
}
918
919pub fn resolve_axes<'a, I: ExactSizeIterator<Item = &'a i32>>(
924 ndim: usize,
925 axes: I,
926) -> Result<SmallVec<[usize; 4]>, OpError> {
927 let mut resolved_axes = SmallVec::with_capacity(axes.len());
928 for axis in axes {
929 let resolved = resolve_axis(ndim, *axis as isize)?;
930 resolved_axes.push(resolved);
931 }
932 Ok(resolved_axes)
933}
934
/// Expand a block of code for the tensor variants of a [`ValueView`], with
/// `$typed_input` bound to the typed tensor view inside `$block`.
///
/// The first form covers every tensor variant and evaluates to
/// `Err(OpError::UnsupportedType)` for sequences. The second form expands
/// `$block` only for the listed variants and returns early with
/// `OpError::UnsupportedType` for all others.
macro_rules! map_value_view {
    ($input:expr, $typed_input:ident, $block:tt) => {
        match $input {
            ValueView::FloatTensor($typed_input) => $block,
            ValueView::Int32Tensor($typed_input) => $block,
            ValueView::UInt8Tensor($typed_input) => $block,
            ValueView::Int8Tensor($typed_input) => $block,
            ValueView::Sequence(_) => Err(OpError::UnsupportedType)
        }
    };

    ($input:expr, $typed_input:ident, [$($variant:ident),+], $block:tt) => {
        match $input {
            $(ValueView::$variant($typed_input) => $block),+,
            _ => {
                return Err(OpError::UnsupportedType);
            }
        }
    };
}

use map_value_view;
968
/// Expand a block of code with a local type alias `$type` bound to the Rust
/// element type corresponding to a [`DataType`] value.
macro_rules! map_dtype {
    ($dtype:expr, $type:ident, $block:tt) => {{
        use $crate::ops::DataType;

        match $dtype {
            DataType::Int32 => {
                type $type = i32;
                $block
            }
            DataType::Float => {
                type $type = f32;
                $block
            }
            DataType::UInt8 => {
                type $type = u8;
                $block
            }
            DataType::Int8 => {
                type $type = i8;
                $block
            }
        }
    }};
}

use map_dtype;
999
/// Expand a block of code for the tensor variants of an owned [`Value`],
/// with `$typed_input` bound mutably to the typed tensor inside `$block`.
///
/// Like [`map_value_view`], but matching owned values. The first form
/// covers every tensor variant and evaluates to
/// `Err(OpError::UnsupportedType)` for sequences; the second expands
/// `$block` only for the listed variants and returns early with
/// `OpError::UnsupportedType` for all others.
macro_rules! map_value {
    ($input:expr, $typed_input:ident, $block:tt) => {
        match $input {
            #[allow(unused_mut)]
            Value::FloatTensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::Int32Tensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::UInt8Tensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::Int8Tensor(mut $typed_input) => $block,
            Value::Sequence(_) => Err(OpError::UnsupportedType),
        }
    };

    ($input:expr, $typed_input:ident, [$($variant:ident),+], $block:tt) => {
        match $input {
            $(
                #[allow(unused_mut)]
                Value::$variant(mut $typed_input) => $block
            ),+,
            _ => {
                return Err(OpError::UnsupportedType);
            }
        }
    };
}

use map_value;
1040
/// Return early from the enclosing function with
/// `Err(OpError::$err_variant($err_msg))` if `$condition` is false.
macro_rules! check_value {
    ($condition:expr, $err_variant:ident, $err_msg:expr) => {
        if !$condition {
            return Err(OpError::$err_variant($err_msg));
        }
    };
}

use check_value;
1055
1056#[cfg(test)]
1057mod tests {
1058 use rten_tensor::prelude::*;
1059 use rten_tensor::test_util::{ExpectEqualError, expect_equal_with_tolerance};
1060 use rten_tensor::{NdTensor, Tensor, TensorView};
1061
1062 use super::Operator;
1063 use crate::buffer_pool::BufferPool;
1064 use crate::ops::{Add, InputList, OpError, Sub};
1065
    /// Create a buffer pool for use by operator tests.
    pub fn new_pool() -> BufferPool {
        BufferPool::new()
    }
1073
    /// Compare two float tensors with an absolute tolerance of 1e-4 and no
    /// relative tolerance.
    pub fn expect_eq_1e4<V: AsView<Elem = f32>>(
        result: &V,
        expected: &V,
    ) -> Result<(), ExpectEqualError> {
        expect_equal_with_tolerance(result, expected, 1e-4, 0.)
    }
1085
    /// Trait for converting a tensor into one with a higher static rank `N`.
    pub trait IntoNDim<const N: usize> {
        /// The tensor type after conversion.
        type Output;

        /// Convert this tensor to rank `N` by prepending size-1 dimensions.
        fn into_ndim(self) -> Self::Output;
    }
1097
    impl<T: Clone, const M: usize, const N: usize> IntoNDim<N> for NdTensor<T, M> {
        type Output = NdTensor<T, N>;

        fn into_ndim(self) -> Self::Output {
            // Only rank increases are supported.
            assert!(N >= M);
            let new_dims = N - M;
            let shape = self.shape();
            // Prepend `N - M` size-1 dimensions, keeping the original
            // dimensions at the end.
            let new_shape =
                std::array::from_fn(|d| if d < new_dims { 1 } else { shape[d - new_dims] });
            self.into_shape(new_shape)
        }
    }
1110
    #[test]
    fn test_input_list_first_input_omitted() {
        let tensor = Tensor::<f32>::zeros(&[2, 2]);

        // Without the offset, cast errors report the list index directly.
        let inputs = InputList::from(&[tensor.view().into()]).with_first_input_omitted(false);
        let err = inputs.require_as::<TensorView<i32>>(0).err().unwrap();
        assert!(matches!(err, OpError::InputCastFailed { index: 0, .. }));

        // With the first input omitted, reported indices are shifted by one.
        let inputs = InputList::from(&[tensor.view().into()]).with_first_input_omitted(true);
        let err = inputs.require_as::<TensorView<i32>>(0).err().unwrap();
        assert!(matches!(err, OpError::InputCastFailed { index: 1, .. }));
    }
1123
1124 #[test]
1125 fn test_downcast_operator() {
1126 let add_op = Add {};
1127 let sub_op = Sub {};
1128
1129 let add_op_dyn: &dyn Operator = &add_op;
1130 let sub_op_dyn: &dyn Operator = &sub_op;
1131
1132 assert!(add_op_dyn.downcast_ref::<Add>().is_some());
1133 assert!(sub_op_dyn.downcast_ref::<Sub>().is_some());
1134 }
1135}