use std::fmt::Debug;
use rten_shape_inference::ops as shape_ops;
use smallvec::SmallVec;
use crate::operator::OpError;
use crate::value::DataType;
mod attention;
mod binary_elementwise;
mod concat;
mod control_flow;
mod conv;
mod conv_transpose;
mod convert;
mod einsum;
mod gather;
mod generate;
mod grid_sample;
mod identity;
mod layout;
mod matmul;
mod non_max_suppression;
mod norm;
mod pad;
mod pooling;
mod quantize;
#[cfg(feature = "fft")]
mod fft;
#[cfg(feature = "random")]
mod random;
mod reduce;
mod resize;
mod rnn;
mod sequence;
mod slice;
mod split;
mod trilu;
mod unary_elementwise;
mod variadic_elementwise;
pub(crate) mod transform_inputs;
#[cfg(feature = "fft")]
pub(crate) use fft::STFT;
#[cfg(feature = "random")]
pub(crate) use random::{
Dropout, RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike,
};
pub(crate) use {
attention::{AddSoftmax, GroupedQueryAttentionMatMul, RepeatInterleave},
binary_elementwise::{
Add, And, Div, Equal, Greater, GreaterOrEqual, Less, LessOrEqual, Mod, Mul, Or, Pow, Sub,
Where, Xor,
},
concat::{Concat, Tile},
control_flow::{If, Loop},
conv::{Conv, ConvInteger},
conv_transpose::ConvTranspose,
convert::{Cast, CastLike},
einsum::Einsum,
gather::{Gather, GatherElements, GatherND, ScatterElements, ScatterND, ScatterReduction},
generate::{ConstantOfShape, EyeLike, OneHot, Range},
grid_sample::GridSample,
identity::Identity,
layout::{
ComputeShape, DepthToSpace, DimSpec, Expand, Flatten, Reshape, Shape, Size, Squeeze,
Transpose, Unsqueeze,
},
matmul::{
AccuracyLevel, FusedMatMul, Gemm, MatMul, MatMulInteger, MatMulIntegerToFloat, MatMulNBits,
},
non_max_suppression::NonMaxSuppression,
norm::{
BatchNormalization, InstanceNormalization, LayerNormalization, LogSoftmax,
RmsNormalization, Softmax,
},
pad::Pad,
pooling::{AveragePool, GlobalAveragePool, GlobalMaxPool, MaxPool},
quantize::{DequantizeLinear, DynamicQuantizeLinear, QuantizeLinear},
reduce::{
ArgMax, ArgMin, CumSum, NonZero, ReduceL2, ReduceMax, ReduceMean, ReduceMin, ReduceProd,
ReduceSum, ReduceSumSquare, TopK,
},
resize::Resize,
rnn::{GRU, LSTM},
sequence::{
ConcatFromSequence, SequenceAt, SequenceConstruct, SequenceEmpty, SequenceErase,
SequenceInsert, SequenceLength, SplitToSequence,
},
slice::Slice,
split::Split,
trilu::Trilu,
unary_elementwise::{
Abs, Acos, Asin, Atan, Ceil, Clip, Cos, Elu, Erf, Exp, Floor, Gelu, HardSigmoid, HardSwish,
IsInf, IsNaN, LeakyRelu, Log, Neg, Not, PRelu, Reciprocal, Relu, Round, Sigmoid, Sign,
Silu, Sin, Softplus, Sqrt, Swish, Tan, Tanh,
},
variadic_elementwise::{Max, Mean, Min, Sum},
};
pub use binary_elementwise::{
DivMode, add, and, div, equal, greater, greater_or_equal, less, less_or_equal, mod_op, mul, or,
pow, sub, where_op, xor,
};
pub use concat::{concat, tile};
pub use conv::{conv, conv_integer};
pub use conv_transpose::conv_transpose;
pub use einsum::einsum;
pub use gather::{gather, gather_elements, gather_nd, scatter_elements, scatter_nd};
pub use generate::{constant_of_shape, onehot, range};
pub use layout::{DepthToSpaceMode, depth_to_space, expand, flatten, reshape, squeeze};
pub use matmul::{gemm, matmul};
pub use non_max_suppression::{BoxOrder, non_max_suppression};
pub use norm::{
batch_norm, instance_normalization, layer_normalization, log_softmax, rms_normalization,
softmax,
};
pub use pad::{PadMode, pad};
pub use pooling::{average_pool, global_average_pool, max_pool};
pub use quantize::{dequantize_linear, dynamic_quantize_linear, quantize_linear};
#[cfg(feature = "fft")]
pub use fft::stft;
pub use reduce::{
arg_max, arg_min, cum_sum, nonzero, reduce_l2, reduce_max, reduce_mean, reduce_min,
reduce_prod, reduce_sum, reduce_sum_square, topk,
};
pub use resize::{CoordTransformMode, NearestMode, ResizeMode, ResizeTarget, resize, resize_image};
pub use rnn::{Direction, gru, lstm};
pub use slice::slice;
pub use split::split;
pub use trilu::trilu;
pub use variadic_elementwise::{max, mean, min, sum};
mod operators;
pub use operators::{FloatOperators, Operators};
/// Padding applied to the spatial dimensions of an input by operations such
/// as convolution and pooling.
#[derive(Clone, Debug, PartialEq)]
pub enum Padding {
    /// Automatically compute padding (ONNX "SAME" auto-pad style —
    /// NOTE(review): exact same-upper vs same-lower semantics not visible
    /// here; confirm against the conv/pool implementations).
    Same,
    /// Explicit pad amounts, two per spatial dimension. Layout appears to be
    /// `[start_0, start_1, ..., end_0, end_1, ...]` judging by
    /// `expand_1d_to_2d` — TODO confirm.
    Fixed(SmallVec<[usize; 4]>),
}
impl Padding {
    /// Return all-zero fixed padding for an input with `N` spatial
    /// dimensions (two pad values per dimension).
    pub fn zero<const N: usize>() -> Padding {
        Padding::Fixed(SmallVec::from_elem(0, N * 2))
    }

    /// Expand 1D padding `[start, end]` to the 2D padding
    /// `[0, start, 0, end]`, leaving `Same` padding unchanged.
    ///
    /// Returns an error for fixed padding that does not contain exactly two
    /// values.
    pub fn expand_1d_to_2d(&self) -> Result<Padding, OpError> {
        match self {
            Padding::Same => Ok(Padding::Same),
            Padding::Fixed(pads) => {
                if let &[pad_start, pad_end] = pads.as_slice() {
                    Ok([0, pad_start, 0, pad_end].into())
                } else {
                    Err(OpError::InvalidValue("expected 2 pad values"))
                }
            }
        }
    }

    /// Borrow this padding in the representation used by the
    /// `rten_shape_inference` crate.
    pub fn as_shape_inference_padding(&self) -> shape_ops::Padding<'_> {
        match self {
            Padding::Same => shape_ops::Padding::Same,
            Padding::Fixed(pads) => shape_ops::Padding::Fixed(pads),
        }
    }
}
impl<S: AsRef<[usize]>> From<S> for Padding {
    /// Build fixed padding from any slice-like collection of pad amounts.
    fn from(val: S) -> Padding {
        let pads = val.as_ref();
        Padding::Fixed(pads.into())
    }
}
/// Map a possibly-negative `index` into a container of length `len` to a
/// non-negative index, with Python-style semantics where `-1` refers to the
/// last element.
///
/// Returns `None` if `index` is out of bounds in either direction.
fn resolve_index(len: usize, index: isize) -> Option<usize> {
    let signed_len = len as isize;
    match index {
        i if i < -signed_len || i >= signed_len => None,
        i if i >= 0 => Some(i as usize),
        // Negative and in range: count back from the end.
        i => Some((signed_len + i) as usize),
    }
}
/// Resolve a possibly-negative `axis` for a tensor with `ndim` dimensions,
/// returning an error rather than `None` when the axis is out of range.
fn resolve_axis(ndim: usize, axis: isize) -> Result<usize, OpError> {
    match resolve_index(ndim, axis) {
        Some(resolved) => Ok(resolved),
        None => Err(OpError::InvalidValue("Axis is invalid")),
    }
}
/// Resolve an iterator of possibly-negative axis indices into non-negative
/// in-range indices for a tensor with `ndim` dimensions.
///
/// Fails with the first invalid axis encountered.
pub fn resolve_axes<'a, I: ExactSizeIterator<Item = &'a i32>>(
    ndim: usize,
    axes: I,
) -> Result<SmallVec<[usize; 4]>, OpError> {
    // `collect` into `Result<SmallVec, _>` short-circuits on the first error.
    axes.map(|&axis| resolve_axis(ndim, axis as isize)).collect()
}
/// Match a `ValueView` against its tensor variants and evaluate `$block`
/// with `$typed_input` bound to the dtype-specific tensor view.
///
/// The first form covers all tensor dtypes and evaluates to
/// `Err(OpError::UnsupportedType)` for sequence values. The second form
/// accepts only the listed variants and does an early `return` with
/// `Err(OpError::UnsupportedType)` for anything else.
macro_rules! map_value_view {
    ($input:expr, $typed_input:ident, $block:tt) => {
        match $input {
            ValueView::FloatTensor($typed_input) => $block,
            ValueView::Int32Tensor($typed_input) => $block,
            ValueView::UInt8Tensor($typed_input) => $block,
            ValueView::Int8Tensor($typed_input) => $block,
            ValueView::Sequence(_) => Err(OpError::UnsupportedType)
        }
    };

    // Variant-restricted form:
    // `map_value_view!(input, x, [FloatTensor, Int32Tensor], { ... })`.
    ($input:expr, $typed_input:ident, [$($variant:ident),+], $block:tt) => {
        match $input {
            $(ValueView::$variant($typed_input) => $block),+,
            _ => {
                return Err(OpError::UnsupportedType);
            }
        }
    };
}

// Re-export so sibling modules can `use super::map_value_view`.
use map_value_view;
/// Expand `$block` with the type alias `$type` bound to the Rust element
/// type corresponding to a runtime `DataType` value. This lets dtype-generic
/// code be written once and instantiated per dtype.
macro_rules! map_dtype {
    ($dtype:expr, $type:ident, $block:tt) => {{
        // Bring `DataType` into scope so callers don't need to import it.
        use $crate::ops::DataType;
        match $dtype {
            DataType::Int32 => {
                type $type = i32;
                $block
            }
            DataType::Float => {
                type $type = f32;
                $block
            }
            DataType::UInt8 => {
                type $type = u8;
                $block
            }
            DataType::Int8 => {
                type $type = i8;
                $block
            }
        }
    }};
}

// Re-export so sibling modules can `use super::map_dtype`.
use map_dtype;
/// Owned-value counterpart of `map_value_view`: match a `Value` against its
/// tensor variants and evaluate `$block` with `$typed_input` bound (mutably,
/// so the block may consume or modify the tensor) to the typed tensor.
///
/// The first form covers all tensor dtypes and evaluates to
/// `Err(OpError::UnsupportedType)` for sequence values. The second form
/// accepts only the listed variants and does an early `return` with
/// `Err(OpError::UnsupportedType)` for anything else.
macro_rules! map_value {
    ($input:expr, $typed_input:ident, $block:tt) => {
        match $input {
            #[allow(unused_mut)]
            Value::FloatTensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::Int32Tensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::UInt8Tensor(mut $typed_input) => $block,
            #[allow(unused_mut)]
            Value::Int8Tensor(mut $typed_input) => $block,
            Value::Sequence(_) => Err(OpError::UnsupportedType),
        }
    };

    // Variant-restricted form:
    // `map_value!(input, x, [FloatTensor, Int32Tensor], { ... })`.
    ($input:expr, $typed_input:ident, [$($variant:ident),+], $block:tt) => {
        match $input {
            $(
                #[allow(unused_mut)]
                Value::$variant(mut $typed_input) => $block
            ),+,
            _ => {
                return Err(OpError::UnsupportedType);
            }
        }
    };
}

// Re-export so sibling modules can `use super::map_value`.
use map_value;
/// Early-return `Err(OpError::$err_variant($err_msg))` from the enclosing
/// function if `$condition` is false. A lightweight precondition check for
/// operator inputs.
macro_rules! check_value {
    ($condition:expr, $err_variant:ident, $err_msg:expr) => {
        if !$condition {
            return Err(OpError::$err_variant($err_msg));
        }
    };
}

// Re-export so sibling modules can `use super::check_value`.
use check_value;
#[cfg(test)]
mod tests {
    use rten_tensor::NdTensor;
    use rten_tensor::prelude::*;
    use rten_tensor::test_util::{ExpectEqualError, expect_equal_with_tolerance};

    /// Compare two float tensors with an absolute tolerance of 1e-4 and no
    /// relative tolerance.
    pub fn expect_eq_1e4<V: AsView<Elem = f32>>(
        result: &V,
        expected: &V,
    ) -> Result<(), ExpectEqualError> {
        expect_equal_with_tolerance(result, expected, 1e-4, 0.)
    }

    /// Convert a tensor to a higher rank `N` by prepending size-1 dims.
    pub trait IntoNDim<const N: usize> {
        type Output;

        fn into_ndim(self) -> Self::Output;
    }

    impl<T: Clone, const M: usize, const N: usize> IntoNDim<N> for NdTensor<T, M> {
        type Output = NdTensor<T, N>;

        fn into_ndim(self) -> Self::Output {
            // Only up-ranking is supported; dropping dims is not meaningful
            // here.
            assert!(N >= M);
            let leading = N - M;
            let src_shape = self.shape();
            // Size-1 dims first, then the original shape shifted right.
            let padded = std::array::from_fn(|dim| {
                if dim < leading {
                    1
                } else {
                    src_shape[dim - leading]
                }
            });
            self.into_shape(padded)
        }
    }
}