rten/ops/
mod.rs

1//! The `ops` module exposes the various operators available for machine-learning
2//! models.
3//!
4//! Most operators correspond to an [ONNX
5//! Operator](https://onnx.ai/onnx/operators/) of the same name, though RTen
6//! does not support all ONNX operators, data types or attributes.
7//!
8//! Operators are primarily invoked by RTen as part of executing a
9//! [Model](crate::Model), however they are also exposed as standalone
10//! functions and tensor methods for use in code that pre-processes model
//! inputs and post-processes model outputs. Some standalone operator functions
//! come in two flavors: one which operates in-place on an existing tensor,
//! and one which takes a view as input and returns a new tensor as output.
14
15use std::any::Any;
16use std::borrow::Cow;
17use std::convert::Infallible;
18use std::error::Error;
19use std::fmt;
20use std::fmt::{Debug, Display};
21
22use smallvec::SmallVec;
23
24use rten_gemm::PackedBMatrix;
25use rten_tensor::errors::DimensionError;
26use rten_tensor::{MutLayout, Storage, TensorBase};
27
28use crate::buffer_pool::BufferPool;
29use crate::graph::{CaptureEnv, Graph, RunError, RunOptions};
30use crate::timing::Profiler;
31use crate::value::{CastError, DataType, DataTypeOf, Value, ValueOrView, ValueView};
32use crate::weight_cache::WeightCache;
33
34mod attention;
35mod binary_elementwise;
36mod concat;
37mod control_flow;
38mod conv;
39mod conv_transpose;
40mod convert;
41mod einsum;
42mod gather;
43mod generate;
44mod grid_sample;
45mod identity;
46mod layout;
47mod matmul;
48mod non_max_suppression;
49mod norm;
50mod pad;
51mod pooling;
52mod quantize;
53
54#[cfg(feature = "fft")]
55mod fft;
56
57#[cfg(feature = "random")]
58mod random;
59
60mod reduce;
61mod resize;
62mod rnn;
63mod sequence;
64mod slice;
65mod split;
66mod trilu;
67mod unary_elementwise;
68mod variadic_elementwise;
69
70// Fused operations
71pub(crate) mod transform_inputs;
72
73pub use attention::AddSoftmax;
74pub use binary_elementwise::{
75    Add, And, Div, DivMode, Equal, Greater, GreaterOrEqual, Less, LessOrEqual, Mod, Mul, Or, Pow,
76    Sub, Where, Xor, add, and, div, equal, greater, greater_or_equal, less, less_or_equal, mod_op,
77    mul, or, pow, sub, where_op, xor,
78};
79pub use concat::{Concat, Tile, concat, tile};
80pub use control_flow::{If, Loop};
81pub use conv::{Conv, ConvInteger, conv, conv_integer};
82pub use conv_transpose::{ConvTranspose, conv_transpose};
83pub use convert::{Cast, CastLike};
84pub use einsum::{Einsum, einsum};
85pub use gather::{
86    Gather, GatherElements, GatherND, ScatterElements, ScatterND, ScatterReduction, gather,
87    gather_elements, gather_nd, scatter_elements, scatter_nd,
88};
89pub use generate::{ConstantOfShape, EyeLike, OneHot, Range, constant_of_shape, onehot, range};
90pub use grid_sample::GridSample;
91pub use identity::Identity;
92pub use layout::{
93    DepthToSpace, DepthToSpaceMode, Expand, Flatten, Reshape, Shape, Size, Squeeze, Transpose,
94    Unsqueeze, depth_to_space, expand, flatten, reshape, squeeze,
95};
96pub use matmul::{FusedMatMul, Gemm, MatMul, MatMulInteger, MatMulIntegerToFloat, gemm_op, matmul};
97pub use non_max_suppression::{BoxOrder, NonMaxSuppression, non_max_suppression};
98pub use norm::{
99    BatchNormalization, InstanceNormalization, LayerNormalization, LogSoftmax, RmsNormalization,
100    Softmax, batch_norm, instance_normalization, layer_normalization, log_softmax,
101    rms_normalization, softmax,
102};
103pub use pad::{Pad, PadMode, pad};
104pub use pooling::{
105    AveragePool, GlobalAveragePool, MaxPool, average_pool, global_average_pool, max_pool,
106};
107pub use quantize::{
108    DequantizeLinear, DynamicQuantizeLinear, QuantizeLinear, dequantize_linear,
109    dynamic_quantize_linear, quantize_linear,
110};
111
112#[cfg(feature = "fft")]
113pub use fft::{STFT, stft};
114
115#[cfg(feature = "random")]
116pub use random::{Dropout, RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike};
117
118pub use reduce::{
119    ArgMax, ArgMin, CumSum, NonZero, ReduceL2, ReduceMax, ReduceMean, ReduceMin, ReduceProd,
120    ReduceSum, ReduceSumSquare, TopK, arg_max, arg_min, cum_sum, nonzero, reduce_l2, reduce_max,
121    reduce_mean, reduce_min, reduce_prod, reduce_sum, reduce_sum_square, topk,
122};
123pub use resize::{
124    CoordTransformMode, NearestMode, Resize, ResizeMode, ResizeTarget, resize, resize_image,
125};
126pub use rnn::{Direction, GRU, LSTM, gru, lstm};
127pub use sequence::{
128    ConcatFromSequence, SequenceAt, SequenceConstruct, SequenceEmpty, SequenceErase,
129    SequenceInsert, SequenceLength, SplitToSequence,
130};
131pub use slice::{Slice, slice};
132pub use split::{Split, split};
133pub use trilu::{Trilu, trilu};
134pub use unary_elementwise::{
135    Abs, Acos, Asin, Atan, Ceil, Clip, Cos, Elu, Erf, Exp, Floor, Gelu, HardSigmoid, HardSwish,
136    IsInf, IsNaN, LeakyRelu, Log, Neg, Not, PRelu, Reciprocal, Relu, Round, Sigmoid, Sign, Silu,
137    Sin, Softplus, Sqrt, Swish, Tan, Tanh,
138};
139pub use variadic_elementwise::{Max, Mean, Min, Sum, max, mean, min, sum};
140
141mod operators;
142pub use operators::{FloatOperators, Operators};
143
/// Specifies the padding applied to the spatial dimensions of an input by
/// operators such as convolution and pooling.
#[derive(Clone, Debug)]
pub enum Padding {
    /// Apply enough padding such that the output and input have the same size.
    ///
    /// If the required amount of padding along each dimension is even, it is
    /// divided equally between the start and the end. If it is odd, one more
    /// unit is added on the end than the start. This matches the ONNX spec
    /// for the "SAME_UPPER" value for the `auto_pad` attribute.
    Same,

    /// Apply a given amount of padding to each side of the input. Paddings
    /// are specified in the order `[start, end]` for 1D padding,
    /// `[top, left, bottom, right]` for 2D and so on.
    ///
    /// The vector is expected to contain `2 * N` entries for an input with
    /// N spatial dimensions (see [`Padding::zero`]).
    Fixed(SmallVec<[usize; 4]>),
}
159
160impl Padding {
161    /// Return fixed zero padding for an N-dimensional shape.
162    pub fn zero<const N: usize>() -> Padding {
163        Padding::Fixed(SmallVec::from_elem(0, N * 2))
164    }
165
166    /// Expand padding for a 1D operation to 2D.
167    pub fn expand_1d_to_2d(&self) -> Result<Padding, OpError> {
168        match self {
169            Padding::Same => Ok(Padding::Same),
170            Padding::Fixed(pads) => match pads.as_slice() {
171                &[pad_start, pad_end] => Ok([0, pad_start, 0, pad_end].into()),
172                _ => Err(OpError::InvalidValue("expected 2 pad values")),
173            },
174        }
175    }
176}
177
178/// Construct a [`Padding::Fixed`] from a slice of paddings for each size.
179impl<S: AsRef<[usize]>> From<S> for Padding {
180    fn from(val: S) -> Padding {
181        Padding::Fixed(val.as_ref().into())
182    }
183}
184
/// An operator input which has been pre-packed for more efficient use during
/// inference.
///
/// Pre-packed inputs are created via [`Operator::prepack`] and made available
/// to operators via [`InputList::with_prepacked`].
pub enum PrepackedInput {
    /// Prepacked RHS / B input for matrix multiplication with f32 weights.
    FloatBMatrix(PackedBMatrix<f32>),

    /// Prepacked RHS / B input for matrix multiplication with i8 weights.
    Int8BMatrix(PackedBMatrix<i8>),
}
194
195impl PrepackedInput {
196    fn dtype(&self) -> DataType {
197        match self {
198            Self::FloatBMatrix(_) => DataType::Float,
199            Self::Int8BMatrix(_) => DataType::Int8,
200        }
201    }
202}
203
/// Generate conversions between `PackedBMatrix<$type>` and the corresponding
/// `PrepackedInput::$variant` wrapper.
macro_rules! impl_prepacked_input_conversions {
    ($type:ty, $variant:ident) => {
        // Wrap a packed matrix in the corresponding enum variant.
        impl From<PackedBMatrix<$type>> for PrepackedInput {
            fn from(value: PackedBMatrix<$type>) -> Self {
                PrepackedInput::$variant(value)
            }
        }

        // Fallible unwrap, reporting the actual vs expected dtype on mismatch.
        impl<'a> TryFrom<&'a PrepackedInput> for &'a PackedBMatrix<$type> {
            type Error = CastError;

            fn try_from(ppi: &'a PrepackedInput) -> Result<Self, Self::Error> {
                match ppi {
                    PrepackedInput::$variant(packed) => Ok(packed),
                    _ => Err(CastError::WrongType {
                        actual: ppi.dtype(),
                        expected: <$type as DataTypeOf>::dtype_of(),
                    }),
                }
            }
        }
    };
}
impl_prepacked_input_conversions!(f32, FloatBMatrix);
impl_prepacked_input_conversions!(i8, Int8BMatrix);
229
/// Deprecated alias for [`ValueOrView`].
#[deprecated = "renamed to `ValueOrView`"]
pub type InputOrOutput<'a> = ValueOrView<'a>;

/// Deprecated alias for [`ValueView`].
#[deprecated = "renamed to `ValueView`"]
pub type Input<'a> = ValueView<'a>;

/// Deprecated alias for [`Value`].
#[deprecated = "renamed to `Value`"]
pub type Output = Value;
238
/// Trait for values that can be converted into the result type used by
/// [`Operator::run`].
pub trait IntoOpResult {
    /// Convert this value into an operator result, wrapping single outputs
    /// in a one-element [`OutputList`].
    fn into_op_result(self) -> Result<OutputList, OpError>;
}
244
245impl IntoOpResult for Result<Value, OpError> {
246    fn into_op_result(self) -> Result<OutputList, OpError> {
247        self.map(|out| [out].into())
248    }
249}
250
251impl IntoOpResult for Value {
252    fn into_op_result(self) -> Result<OutputList, OpError> {
253        Ok([self].into())
254    }
255}
256
257impl<S: Storage, L: MutLayout> IntoOpResult for TensorBase<S, L>
258where
259    Value: From<TensorBase<S, L>>,
260{
261    fn into_op_result(self) -> Result<OutputList, OpError> {
262        let output: Value = self.into();
263        Ok([output].into())
264    }
265}
266
267impl<S: Storage, L: MutLayout> IntoOpResult for Result<TensorBase<S, L>, OpError>
268where
269    Value: From<TensorBase<S, L>>,
270{
271    fn into_op_result(self) -> Result<OutputList, OpError> {
272        self.map(|tensor| [tensor.into()].into())
273    }
274}
275
276impl<T> IntoOpResult for Result<Vec<T>, OpError>
277where
278    Value: From<T>,
279{
280    fn into_op_result(self) -> Result<OutputList, OpError> {
281        self.map(|tensors| tensors.into_iter().map(|t| t.into()).collect())
282    }
283}
284
/// Possible reasons why an operator may fail on a given input.
///
/// Returned by [`Operator::run`] and the various standalone operator
/// functions in this module.
#[derive(Eq, PartialEq, Debug)]
pub enum OpError {
    /// Casting a tensor to an expected type or rank failed.
    CastFailed(CastError),

    /// Casting an input to an expected type or rank failed.
    ///
    /// `index` is the position of the offending input in the operator's
    /// input list.
    InputCastFailed { index: usize, error: CastError },

    /// A tensor has an unsupported type.
    UnsupportedType,

    /// Input tensor shapes are not compatible with each other or operator
    /// attributes.
    IncompatibleInputShapes(&'static str),

    /// The number of inputs was less than the required number.
    MissingInputs,

    /// An input has a value that is incorrect.
    InvalidValue(&'static str),

    /// An input or attribute has a value that is valid, but not currently supported.
    UnsupportedValue(&'static str),
}
310
311impl OpError {
312    /// Associate this error with a given operator input.
313    fn with_input_index(self, index: usize) -> OpError {
314        match self {
315            Self::CastFailed(error) => OpError::InputCastFailed { index, error },
316            Self::InputCastFailed { error, .. } => OpError::InputCastFailed { index, error },
317            other => other,
318        }
319    }
320}
321
322impl From<DimensionError> for OpError {
323    fn from(val: DimensionError) -> OpError {
324        OpError::CastFailed(val.into())
325    }
326}
327
328impl From<CastError> for OpError {
329    fn from(val: CastError) -> OpError {
330        OpError::CastFailed(val)
331    }
332}
333
/// Allow `?` on `Result<_, Infallible>` values in operator code. The
/// conversion can never actually run, as `Infallible` has no values.
impl From<Infallible> for OpError {
    fn from(x: Infallible) -> OpError {
        // Empty match: `Infallible` is uninhabited.
        match x {}
    }
}
339
340impl Display for OpError {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        match self {
343            OpError::CastFailed(err) => write!(f, "{}", err),
344            OpError::InputCastFailed { index, error } => {
345                write!(f, "conversion error for input {}: {}", index, error)
346            }
347            OpError::IncompatibleInputShapes(details) => {
348                write!(f, "incompatible input shapes: {}", details)
349            }
350            OpError::MissingInputs => write!(f, "required inputs were missing"),
351            OpError::InvalidValue(details) => {
352                write!(f, "input or attribute has invalid value: {}", details)
353            }
354            OpError::UnsupportedValue(details) => {
355                write!(f, "unsupported input or attribute value: {}", details)
356            }
357            OpError::UnsupportedType => {
358                write!(f, "unsupported input type")
359            }
360        }
361    }
362}
363
364impl Error for OpError {}
365
/// Convert a tensor with dynamic dimension count to a view with a static
/// dimension count.
///
/// If the conversion fails an `OpError::InvalidValue` error will be returned
/// with a message that includes the name of the tensor and, optionally, the
/// names of the expected dimensions (eg. "NCHW").
macro_rules! static_dims {
    // Arm with dimension names included in the error message.
    ($tensor:ident, $ndim:literal, $dim_names:literal) => {{
        use rten_tensor::prelude::*;

        if $tensor.ndim() != $ndim {
            Err(OpError::InvalidValue(concat!(
                stringify!($tensor),
                " must have ",
                stringify!($ndim),
                " dims (",
                $dim_names,
                ")"
            )))
        } else {
            Ok($tensor.nd_view::<$ndim>())
        }
    }};

    // Arm without dimension names.
    ($tensor:ident, $ndim:literal) => {{
        use rten_tensor::prelude::*;

        if $tensor.ndim() != $ndim {
            Err(OpError::InvalidValue(concat!(
                stringify!($tensor),
                " must have ",
                stringify!($ndim),
                " dims"
            )))
        } else {
            Ok($tensor.nd_view::<$ndim>())
        }
    }};

    // Arm for an `Option`-wrapped tensor. Produces `Option<Result<..>>`.
    ($tensor:ident?, $ndim: expr) => {
        if let Some($tensor) = $tensor.as_ref() {
            Some(static_dims!($tensor, $ndim))
        } else {
            None
        }
    };
}

pub(crate) use static_dims;
415
/// Context passed to [`Operator::run`] containing the information needed for
/// the operator to execute.
pub struct OpRunContext<'a, 'i> {
    /// Pool from which output and temporary buffers should be allocated.
    pool: &'a BufferPool,
    /// Inputs to the operator execution.
    inputs: &'a InputList<'i>,
    /// Requested number of outputs, if explicitly set.
    n_outputs: Option<u32>,
    /// Name of the graph node currently being executed, if known.
    name: Option<&'a str>,
}
424
425impl<'a, 'i> OpRunContext<'a, 'i> {
426    pub fn new(pool: &'a BufferPool, inputs: &'a InputList<'i>) -> Self {
427        OpRunContext {
428            pool,
429            inputs,
430            n_outputs: None,
431            name: None,
432        }
433    }
434
435    /// Construct a new context with the same properties but different inputs.
436    ///
437    /// This is useful when one operator wants to delegate to another.
438    pub fn with_new_inputs<'b, 'il>(&self, inputs: &'b InputList<'il>) -> OpRunContext<'b, 'il>
439    where
440        'a: 'b,
441    {
442        OpRunContext { inputs, ..*self }
443    }
444
445    /// The pool which should be used to allocate large buffers.
446    pub fn pool(&self) -> &BufferPool {
447        self.pool
448    }
449
450    /// Inputs to the operator execution.
451    ///
452    /// For in-place execution via [`Operator::run_in_place`] this contains
453    /// the non in-place inputs.
454    pub fn inputs(&self) -> &InputList<'i> {
455        self.inputs
456    }
457
458    /// Set the requested number of outputs.
459    ///
460    /// This can be used to skip generating outputs that are unused, or in
461    /// the rare cases that the output count cannot be determined from the
462    /// operator's inputs and attributes alone.
463    pub fn set_num_outputs(&mut self, n: u32) {
464        self.n_outputs = Some(n);
465    }
466
467    /// Return the number of requested outputs or `None` if this has not been
468    /// specified.
469    pub fn num_outputs(&self) -> Option<u32> {
470        self.n_outputs
471    }
472
473    /// Set the name of the current node in the graph.
474    pub fn set_name(&mut self, name: Option<&'a str>) {
475        self.name = name;
476    }
477
478    /// Return the name of the current node in the graph.
479    pub fn name(&self) -> Option<&str> {
480        self.name
481    }
482}
483
/// Outputs from an operator.
///
/// This avoids allocations in the common case where an operator produces
/// exactly one output: a single `Value` is stored inline in the `SmallVec`.
pub type OutputList = SmallVec<[Value; 1]>;
489
/// An Operator performs a computation step when executing a data flow graph.
///
/// Operators take zero or more dynamic input values, plus a set of static
/// attributes and produce one or more output values.
///
/// Operators are usually named after the ONNX operator that they implement.
/// See <https://onnx.ai/onnx/operators/>.
pub trait Operator: Any + Debug {
    /// Return a display name for the operator.
    fn name(&self) -> &str;

    /// Execute the operator.
    ///
    /// `ctx` provides access to operator inputs and the [`BufferPool`] from
    /// which the output and temporary buffers should be allocated.
    ///
    /// For operators which have subgraphs (see
    /// [`as_subgraph_op`](Operator::as_subgraph_op)), the
    /// [`SubgraphOperator::run_subgraph`] method should be used instead.
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError>;

    /// Return true if this operator supports in-place execution via
    /// `run_in_place`.
    ///
    /// In-place execution returns results by modifying an existing tensor
    /// instead of allocating a new one. Reducing memory allocations can
    /// significantly speed up graph runs.
    ///
    /// The default implementation returns `false`.
    fn can_run_in_place(&self) -> bool {
        false
    }

    /// Return true if this operator is commutative, meaning that its inputs
    /// can be re-ordered without affecting the result.
    ///
    /// If true, the graph executor may swap inputs before calling the
    /// [`Operator::run_in_place`] implementation.
    ///
    /// The default implementation returns `false`.
    fn is_commutative(&self) -> bool {
        false
    }

    /// Return true if this operator's outputs depend only on its inputs.
    ///
    /// The default implementation returns true, since most operators are
    /// deterministic. Operators such as random number generators however are
    /// not.
    ///
    /// The definition of _deterministic_ used here excludes minor differences
    /// due to eg. the order in which results from parallel sub-problems are
    /// accumulated. It also does not guarantee exact consistency across devices.
    fn is_deterministic(&self) -> bool {
        true
    }

    /// Execute this operator in-place on an existing tensor.
    ///
    /// This may only be called if `can_run_in_place` returns true.
    ///
    /// `input` is the first input, which the implementation may modify and
    /// return as the output. `ctx.inputs()` contains the remaining inputs.
    ///
    /// Operators may fall back to allocating a new output if some property of
    /// the input data or shapes means in-place operation is not possible. In
    /// this case they should return the input buffer to the pool, and allocate
    /// the new output buffer from it. The pool should also be used for any
    /// temporary buffers created during execution.
    ///
    /// The default implementation returns an error.
    fn run_in_place(
        &self,
        #[allow(unused)] input: Value,
        #[allow(unused)] ctx: &OpRunContext,
    ) -> Result<Value, OpError> {
        Err(OpError::InvalidValue("In-place execution not supported"))
    }

    /// Return the IDs of inputs which can be pre-packed using [`prepack`](Operator::prepack).
    ///
    /// The default implementation returns an empty list.
    fn prepack_inputs(&self) -> SmallVec<[usize; 1]> {
        SmallVec::new()
    }

    /// Pre-pack an input for more efficient inference later.
    ///
    /// `index` specifies the input ID and should be one of the inputs returned
    /// by [`prepack_inputs`](Operator::prepack_inputs).
    ///
    /// Returns `None` (the default) if the input cannot be pre-packed.
    fn prepack(
        &self,
        #[allow(unused)] index: usize,
        #[allow(unused)] input: ValueView,
    ) -> Option<PrepackedInput> {
        None
    }

    /// Return the [`SubgraphOperator`] implementation for this operator, if
    /// this operator has subgraphs.
    ///
    /// The default implementation returns `None`.
    fn as_subgraph_op(&self) -> Option<&dyn SubgraphOperator> {
        None
    }
}
586
587impl dyn Operator {
588    /// Downcast this operator to a concrete type.
589    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
590        (self as &dyn Any).downcast_ref()
591    }
592}
593
/// Trait for operators which contain subgraphs, such as `If`, `Loop` etc.
pub trait SubgraphOperator: Operator {
    /// Return a list of subgraphs used by this operator.
    ///
    /// The default implementation returns an empty list.
    fn subgraphs(&self) -> SmallVec<[&Graph; 2]> {
        SmallVec::new()
    }

    /// Execute the operator with the given inputs and captured values.
    ///
    /// This should be used instead of [`Operator::run`] for operators that
    /// implement this trait.
    fn run_subgraph<'a>(
        &'a self,
        ctx: &OpRunContext,
        #[allow(unused)] captures: CaptureEnv,
        #[allow(unused)] weight_cache: Option<&[WeightCache]>,
        #[allow(unused)] profiler: Option<&mut Profiler<'a>>,
        #[allow(unused)] run_opts: Option<RunOptions>,
    ) -> Result<OutputList, RunError>;
}
614
615/// Convenience methods that make it easier to run operators in tests.
616pub trait OperatorExt: Operator {
617    /// Run an operator and extract the first output as a tensor with a given
618    /// type.
619    ///
620    /// `inputs` is a tuple of tensor references or other values that can be
621    /// converted to [`ValueView`].
622    fn run_simple<'a, I: Into<InputList<'a>>, O: TryFrom<Value>>(
623        &self,
624        inputs: I,
625    ) -> Result<O, OpError>
626    where
627        OpError: From<<O as TryFrom<Value>>::Error>,
628    {
629        let pool = BufferPool::new();
630        let inputs = inputs.into();
631        let ctx = OpRunContext::new(&pool, &inputs);
632        let mut outputs = self.run(&ctx)?;
633        Ok(outputs.remove(0).try_into()?)
634    }
635
636    /// Run an operator with a mutable input and extract the first output.
637    fn run_simple_in_place<'a, M: Into<Value>, I: Into<InputList<'a>>, O: TryFrom<Value>>(
638        &self,
639        mut_input: M,
640        inputs: I,
641    ) -> Result<O, OpError>
642    where
643        OpError: From<<O as TryFrom<Value>>::Error>,
644    {
645        let pool = BufferPool::new();
646        let inputs = inputs.into();
647        let ctx = OpRunContext::new(&pool, &inputs);
648        let output = self.run_in_place(mut_input.into(), &ctx)?;
649        let typed_output = output.try_into()?;
650        Ok(typed_output)
651    }
652}
653
654impl<O: ?Sized + Operator> OperatorExt for O {}
655
/// List of inputs for an operator evaluation.
///
/// This is an owned or borrowed collection of `Option<ValueView>`s with methods
/// to conveniently extract inputs and produce appropriate errors if inputs are
/// missing or of the wrong type.
///
/// An InputList can be constructed from tuples of `impl Into<ValueView>` types
/// (eg. `TensorView`, `&Tensor`) via `Into`. It can also be created or
/// extended from iterators of `ValueView`s or `Option<ValueView>`s.
#[derive(Clone)]
pub struct InputList<'a> {
    /// The input entries. `None` entries represent omitted optional inputs.
    inputs: Cow<'a, [Option<ValueView<'a>>]>,

    /// Callback that retrieves the pre-packed copy of an input with a given
    /// index.
    get_prepacked: Option<&'a dyn Fn(usize) -> Option<&'a PrepackedInput>>,

    /// True if the input list does not contain the first operator input because
    /// it is being passed separately. In this case input indices are offset by
    /// one (eg. `inputs.require(0)` will return the second input to the operator).
    first_input_omitted: bool,
}
678
679impl<'a> InputList<'a> {
680    /// Construct an empty input list.
681    pub fn new() -> InputList<'a> {
682        InputList {
683            inputs: Cow::Owned(vec![]),
684            get_prepacked: None,
685            first_input_omitted: false,
686        }
687    }
688
689    /// Mark this input list as not containing the first input to the operator.
690    ///
691    /// This is used together with [`Operator::run_in_place`] where the first
692    /// input is passed separately. When this flag is set the input index is
693    /// adjusted in errors to reflect the real index.
694    pub fn with_first_input_omitted(mut self, offset: bool) -> Self {
695        self.first_input_omitted = offset;
696        self
697    }
698
699    pub fn len(&self) -> usize {
700        self.inputs.len()
701    }
702
703    pub fn is_empty(&self) -> bool {
704        self.inputs.is_empty()
705    }
706
707    /// Append an input to the list.
708    ///
709    /// This will copy the existing inputs into a new owned vector.
710    pub fn push<I: Into<ValueView<'a>>>(&mut self, inp: I) {
711        self.inputs.to_mut().push(Some(inp.into()))
712    }
713
714    /// Append an optional input to the list.
715    ///
716    /// This will copy the existing inputs into a new owned vector.
717    pub fn push_optional<I: Into<ValueView<'a>>>(&mut self, inp: Option<I>) {
718        self.inputs.to_mut().push(inp.map(|inp| inp.into()))
719    }
720
721    /// Construct an input list from a slice of non-optional inputs.
722    ///
723    /// This copies the inputs into a new vector of `Option<ValueView>`s. Using
724    /// [`from_optional`](Self::from_optional) is more efficient.
725    pub fn from(inputs: &[ValueView<'a>]) -> InputList<'a> {
726        InputList {
727            inputs: inputs.iter().cloned().map(Some).collect(),
728            get_prepacked: None,
729            first_input_omitted: false,
730        }
731    }
732
733    /// Construct an input list from a slice of optional inputs.
734    ///
735    /// This is a cheap conversion that borrows `inputs`.
736    pub fn from_optional(inputs: &'a [Option<ValueView<'a>>]) -> InputList<'a> {
737        InputList {
738            inputs: Cow::Borrowed(inputs),
739            get_prepacked: None,
740            first_input_omitted: false,
741        }
742    }
743
744    /// Configure a callback that will get or create a pre-packed copy of the
745    /// input with a given index.
746    pub fn with_prepacked(
747        mut self,
748        lookup: &'a dyn Fn(usize) -> Option<&'a PrepackedInput>,
749    ) -> Self {
750        self.get_prepacked = Some(lookup);
751        self
752    }
753
754    /// Get an optional input.
755    pub fn get(&self, index: usize) -> Option<ValueView<'a>> {
756        self.inputs.get(index).cloned().flatten()
757    }
758
759    /// Get the pre-packed version of a weight input, if available.
760    pub fn get_prepacked(&self, index: usize) -> Option<&'a PrepackedInput> {
761        self.get_prepacked.and_then(|gp| gp(index))
762    }
763
764    /// Get a mutable reference to an input.
765    ///
766    /// This will convert the list into an owned list of inputs first.
767    pub fn get_mut(&mut self, index: usize) -> Option<&mut ValueView<'a>> {
768        self.inputs.to_mut().get_mut(index)?.as_mut()
769    }
770
771    /// Convert an optional input into a tensor or scalar.
772    pub fn get_as<T>(&self, index: usize) -> Result<Option<T>, OpError>
773    where
774        T: TryFrom<ValueView<'a>, Error = CastError>,
775    {
776        self.get(index)
777            .map(|input| {
778                input.try_into().map_err(|error| OpError::InputCastFailed {
779                    index: self.to_real_index(index),
780                    error,
781                })
782            })
783            .transpose()
784    }
785
786    /// Get a required operator input.
787    pub fn require(&self, index: usize) -> Result<ValueView<'a>, OpError> {
788        self.get(index).ok_or(OpError::MissingInputs)
789    }
790
791    /// Convert a required input into a tensor or scalar.
792    pub fn require_as<T>(&self, index: usize) -> Result<T, OpError>
793    where
794        T: TryFrom<ValueView<'a>, Error = CastError>,
795    {
796        self.require(index).and_then(|input| {
797            input.try_into().map_err(|error| OpError::InputCastFailed {
798                index: self.to_real_index(index),
799                error,
800            })
801        })
802    }
803
804    /// Return an iterator over provided inputs.
805    ///
806    /// Use [`Iterator::flatten`] to skip missing optional inputs.
807    pub fn iter<'b>(&'b self) -> impl Iterator<Item = Option<ValueView<'a>>> + 'b {
808        self.inputs.iter().cloned()
809    }
810
811    /// Map an index into this input list back to an index in the full
812    /// sequence of operator inputs.
813    fn to_real_index(&self, index: usize) -> usize {
814        if self.first_input_omitted {
815            index + 1
816        } else {
817            index
818        }
819    }
820}
821
822impl Default for InputList<'_> {
823    fn default() -> Self {
824        Self::new()
825    }
826}
827
828impl<'a, I: Into<ValueView<'a>>> From<I> for InputList<'a> {
829    fn from(val: I) -> InputList<'a> {
830        InputList::from(&[val.into()])
831    }
832}
833
834impl<'a> From<()> for InputList<'a> {
835    fn from(_: ()) -> InputList<'a> {
836        Self::default()
837    }
838}
839
840impl<'a, I1: Into<ValueView<'a>>> From<(I1,)> for InputList<'a> {
841    fn from((a,): (I1,)) -> InputList<'a> {
842        InputList::from(&[a.into()])
843    }
844}
845
846impl<'a, I1: Into<ValueView<'a>>, I2: Into<ValueView<'a>>> From<(I1, I2)> for InputList<'a> {
847    fn from((a, b): (I1, I2)) -> InputList<'a> {
848        InputList::from(&[a.into(), b.into()])
849    }
850}
851
852impl<'a, I1: Into<ValueView<'a>>, I2: Into<ValueView<'a>>, I3: Into<ValueView<'a>>>
853    From<(I1, I2, I3)> for InputList<'a>
854{
855    fn from((a, b, c): (I1, I2, I3)) -> InputList<'a> {
856        InputList::from(&[a.into(), b.into(), c.into()])
857    }
858}
859
860impl<'a> Extend<ValueView<'a>> for InputList<'a> {
861    fn extend<T>(&mut self, iter: T)
862    where
863        T: IntoIterator<Item = ValueView<'a>>,
864    {
865        for item in iter {
866            self.push(item);
867        }
868    }
869}
870
871impl<'a> Extend<Option<ValueView<'a>>> for InputList<'a> {
872    fn extend<T>(&mut self, iter: T)
873    where
874        T: IntoIterator<Item = Option<ValueView<'a>>>,
875    {
876        for item in iter {
877            self.push_optional(item);
878        }
879    }
880}
881
882impl<'a, A> FromIterator<A> for InputList<'a>
883where
884    InputList<'a>: Extend<A>,
885{
886    fn from_iter<T>(iter: T) -> Self
887    where
888        T: IntoIterator<Item = A>,
889    {
890        let mut list = InputList::new();
891        list.extend(iter);
892        list
893    }
894}
895
/// Resolve an index given as a value in `[-len, len-1]` to a positive index in
/// `[0, len)`, or return None if the index is out of bounds.
///
/// Negative values count backwards from the end, as in ONNX / NumPy.
fn resolve_index(len: usize, index: isize) -> Option<usize> {
    let len = len as isize;
    match index {
        i if i >= 0 && i < len => Some(i as usize),
        i if i >= -len && i < 0 => Some((len + i) as usize),
        _ => None,
    }
}
910
911/// Resolve an axis given as a value in `[-ndim, ndim-1]` to the zero-based
912/// dimension of a tensor with `ndim` dimensions.
913///
914/// Negative axis values count backwards from the last dimension.
915fn resolve_axis(ndim: usize, axis: isize) -> Result<usize, OpError> {
916    resolve_index(ndim, axis).ok_or(OpError::InvalidValue("Axis is invalid"))
917}
918
919/// Resolve a sequence of axes values in `[-ndim, ndim-1]` to zero-based dimension
920/// indexes in a tensor with `ndim` dimensions.
921///
922/// Negative axis values count backwards from the last dimension.
923pub fn resolve_axes<'a, I: ExactSizeIterator<Item = &'a i32>>(
924    ndim: usize,
925    axes: I,
926) -> Result<SmallVec<[usize; 4]>, OpError> {
927    let mut resolved_axes = SmallVec::with_capacity(axes.len());
928    for axis in axes {
929        let resolved = resolve_axis(ndim, *axis as isize)?;
930        resolved_axes.push(resolved);
931    }
932    Ok(resolved_axes)
933}
934
/// Extract a typed tensor view from a [`ValueView`] and pass it to a block.
///
/// `$typed_input` is bound to the tensor view held by the matched variant, so
/// the block is instantiated once per supported element type.
///
/// The result of the macro is the result of the block, hence the block must
/// return a value of the same type regardless of the input type. This result
/// type must be a `Result<_, OpError>`.
///
/// A list of supported tensor types can optionally be specified, as a list of
/// [`ValueView`] variant names. Note the difference in failure behavior: the
/// un-filtered arm evaluates to an `Err` value, while the filtered arm does an
/// early `return Err(...)` from the enclosing function for unmatched variants.
///
/// Only tensor types are currently supported. For sequence types this always
/// returns an error.
macro_rules! map_value_view {
    // Arm 1: dispatch over every tensor element type.
    ($input:expr, $typed_input:ident, $block:tt) => {
        match $input {
            ValueView::FloatTensor($typed_input) => $block,
            ValueView::Int32Tensor($typed_input) => $block,
            ValueView::UInt8Tensor($typed_input) => $block,
            ValueView::Int8Tensor($typed_input) => $block,
            ValueView::Sequence(_) => Err(OpError::UnsupportedType)
        }
    };

    // Arm 2: dispatch only over the listed `ValueView` variants.
    ($input:expr, $typed_input:ident, [$($variant:ident),+], $block:tt) => {
            match $input {
                $(ValueView::$variant($typed_input) => $block),+,
                _ => {
                    return Err(OpError::UnsupportedType);
                }
            }
    };
}
966
967use map_value_view;
968
/// Evaluate a block with a type alias defined that matches a [`DataType`].
///
/// For example if `$dtype` is [`DataType::Int32`] then the block will be
/// evaluated with a type named `$type` in scope which is an alias for `i32`.
///
/// Unlike `map_value_view`/`map_value` this dispatches on a [`DataType`] tag
/// rather than on a tensor value, and the match is exhaustive over all
/// `DataType` variants, so the block's result is returned directly.
macro_rules! map_dtype {
    ($dtype:expr, $type:ident, $block:tt) => {{
        // Fully-qualified import so the macro expands correctly regardless of
        // what the caller has in scope.
        use $crate::ops::DataType;

        match $dtype {
            DataType::Int32 => {
                type $type = i32;
                $block
            }
            DataType::Float => {
                type $type = f32;
                $block
            }
            DataType::UInt8 => {
                type $type = u8;
                $block
            }
            DataType::Int8 => {
                type $type = i8;
                $block
            }
        }
    }};
}
997
998use map_dtype;
999
/// Extract a typed owned tensor from a [`Value`] and pass it to a block.
///
/// This is the owned-value counterpart of `map_value_view`: `$typed_input` is
/// bound `mut` to the tensor moved out of the matched variant (the
/// `unused_mut` allows silence warnings for blocks that never mutate it).
///
/// The result of the macro is the result of the block, hence the block must
/// return a value of the same type regardless of the input type. This result
/// type must be a `Result<_, OpError>`.
///
/// A list of supported tensor types can optionally be specified, as a list of
/// [`Value`] variant names. As with `map_value_view`, the filtered arm does an
/// early `return Err(...)` from the enclosing function for unmatched variants.
///
/// Only tensor types are currently supported. For sequence types this always
/// returns an error.
macro_rules! map_value {
    // Arm 1: dispatch over every tensor element type.
    ($input:expr, $typed_input:ident, $block:tt) => {
        match $input {
            #[allow(unused_mut)]
            Value::FloatTensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::Int32Tensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::UInt8Tensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::Int8Tensor(mut $typed_input) => $block,
            Value::Sequence(_) => Err(OpError::UnsupportedType),
        }
    };

    // Arm 2: dispatch only over the listed `Value` variants.
    ($input:expr, $typed_input:ident, [$($variant:ident),+], $block:tt) => {
            match $input {
                $(
                    #[allow(unused_mut)]
                    Value::$variant(mut $typed_input) => $block
                ),+,
                _ => {
                    return Err(OpError::UnsupportedType);
                }
            }
    };
}
1038
1039use map_value;
1040
/// Check that an operator input or attribute is valid or return an [`OpError`]
/// if not.
///
/// This is similar to [`assert`] but it returns an error instead of panicking
/// if the condition evaluates to false.
///
/// `$err_variant` names an [`OpError`] variant that takes a single message
/// argument (eg. `InvalidValue`) and `$err_msg` is the message passed to it.
/// The `return` exits the enclosing function, so this can only be used in
/// functions returning `Result<_, OpError>`.
macro_rules! check_value {
    ($condition:expr, $err_variant:ident, $err_msg:expr) => {
        if !$condition {
            return Err(OpError::$err_variant($err_msg));
        }
    };
}
1053
1054use check_value;
1055
#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::test_util::{ExpectEqualError, expect_equal_with_tolerance};
    use rten_tensor::{NdTensor, Tensor, TensorView};

    use super::Operator;
    use crate::buffer_pool::BufferPool;
    use crate::ops::{Add, InputList, OpError, Sub};

    /// Create an empty tensor pool.
    ///
    /// This is a wrapper that provides a place to customize the behavior of
    /// the pool in tests.
    pub fn new_pool() -> BufferPool {
        BufferPool::new()
    }

    /// Compare two f32 tensors with a higher absolute tolerance (1e-4) than
    /// the default (1e-5).
    ///
    /// Tests that use this generally ought to use a lower tolerance, but
    /// their test expectations will often need updating to a higher precision.
    pub fn expect_eq_1e4<V: AsView<Elem = f32>>(
        result: &V,
        expected: &V,
    ) -> Result<(), ExpectEqualError> {
        expect_equal_with_tolerance(result, expected, 1e-4, 0.)
    }

    /// Increase the rank of a tensor by inserting leading 1-sized dimensions.
    pub trait IntoNDim<const N: usize> {
        /// Variant of `Self` with N dimensions.
        type Output;

        /// Insert leading 1-sized dimensions into the shape of `self` so that
        /// it has N dimensions.
        ///
        /// Panics if `self` already has more than N dimensions.
        fn into_ndim(self) -> Self::Output;
    }

    impl<T: Clone, const M: usize, const N: usize> IntoNDim<N> for NdTensor<T, M> {
        type Output = NdTensor<T, N>;

        fn into_ndim(self) -> Self::Output {
            assert!(N >= M);
            let new_dims = N - M;
            let shape = self.shape();
            // First `new_dims` entries are 1, the rest copy the old shape.
            let new_shape =
                std::array::from_fn(|d| if d < new_dims { 1 } else { shape[d - new_dims] });
            self.into_shape(new_shape)
        }
    }

    /// Check that `with_first_input_omitted` offsets the input index reported
    /// in errors so it matches the operator's logical input index.
    #[test]
    fn test_input_list_first_input_omitted() {
        let tensor = Tensor::<f32>::zeros(&[2, 2]);

        let inputs = InputList::from(&[tensor.view().into()]).with_first_input_omitted(false);
        let err = inputs.require_as::<TensorView<i32>>(0).err().unwrap();
        assert!(matches!(err, OpError::InputCastFailed { index: 0, .. }));

        let inputs = InputList::from(&[tensor.view().into()]).with_first_input_omitted(true);
        let err = inputs.require_as::<TensorView<i32>>(0).err().unwrap();
        assert!(matches!(err, OpError::InputCastFailed { index: 1, .. }));
    }

    #[test]
    fn test_downcast_operator() {
        let add_op = Add {};
        let sub_op = Sub {};

        let add_op_dyn: &dyn Operator = &add_op;
        let sub_op_dyn: &dyn Operator = &sub_op;

        // Downcasting to the operator's concrete type succeeds...
        assert!(add_op_dyn.downcast_ref::<Add>().is_some());
        assert!(sub_op_dyn.downcast_ref::<Sub>().is_some());

        // ...while downcasting to a different operator type yields None.
        assert!(add_op_dyn.downcast_ref::<Sub>().is_none());
        assert!(sub_op_dyn.downcast_ref::<Add>().is_none());
    }
}