use burn_backend::tensor::IndexingUpdateOp;
use core::hash::Hash;
use serde::{Deserialize, Serialize};
use alloc::borrow::ToOwned;
use alloc::boxed::Box;
use alloc::{string::String, vec::Vec};
use burn_backend::{
DType, Distribution, Slice,
ops::{
ConvOptions, ConvTransposeOptions, DeformConvOptions, GridSampleOptions,
GridSamplePaddingMode, InterpolateMode, InterpolateOptions,
},
quantization::QuantScheme,
};
use crate::{ScalarIr, TensorId, TensorIr, TensorStatus};
/// Intermediate representation of a backend-specific custom operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CustomOpIr {
    /// Unique identifier of the custom operation.
    pub id: String,
    /// Descriptors of the tensors read by the operation.
    pub inputs: Vec<TensorIr>,
    /// Descriptors of the tensors produced by the operation.
    pub outputs: Vec<TensorIr>,
}
impl CustomOpIr {
    /// Create a new custom operation IR from its id and tensor descriptors.
    pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self {
        Self {
            id: id.to_owned(),
            inputs: inputs.to_vec(),
            outputs: outputs.to_vec(),
        }
    }
    /// View the inputs and outputs as fixed-size arrays.
    ///
    /// # Panics
    ///
    /// Panics when the number of inputs (resp. outputs) does not match
    /// `N_IN` (resp. `N_OUT`).
    pub fn as_fixed<const N_IN: usize, const N_OUT: usize>(
        &self,
    ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) {
        // `expect` does not interpolate `{}` placeholders, so the previous
        // messages printed literal braces; use `panic!` to report real counts.
        (
            self.inputs.as_slice().try_into().unwrap_or_else(|_| {
                panic!(
                    "Wrong number of inputs (expected {N_IN}, got {}), check your implementation",
                    self.inputs.len()
                )
            }),
            self.outputs.as_slice().try_into().unwrap_or_else(|_| {
                panic!(
                    "Wrong number of outputs (expected {N_OUT}, got {}), check your implementation",
                    self.outputs.len()
                )
            }),
        )
    }
    /// Iterator over the input tensor descriptors.
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        Box::new(self.inputs.iter())
    }
    /// Iterator over the output tensor descriptors.
    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        Box::new(self.outputs.iter())
    }
}
/// Top-level intermediate representation of a tensor operation,
/// partitioned by the tensor kind the operation acts on.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum OperationIr {
    /// Base operation on float tensors.
    BaseFloat(BaseOperationIr),
    /// Base operation on int tensors.
    BaseInt(BaseOperationIr),
    /// Base operation on bool tensors.
    BaseBool(BaseOperationIr),
    /// Numeric operation on float tensors, tagged with the element dtype.
    NumericFloat(DType, NumericOperationIr),
    /// Numeric operation on int tensors, tagged with the element dtype.
    NumericInt(DType, NumericOperationIr),
    /// Bool-only operation.
    Bool(BoolOperationIr),
    /// Int-only operation.
    Int(IntOperationIr),
    /// Float-only operation, tagged with the element dtype.
    Float(DType, FloatOperationIr),
    /// Neural-network module operation (conv, pooling, ...).
    Module(ModuleOperationIr),
    /// Tensor initialization.
    Init(InitOperationIr),
    /// Backend-specific custom operation.
    Custom(CustomOpIr),
    /// Drop (free) a tensor.
    Drop(TensorIr),
}
/// Operations available only on float tensors (transcendental functions,
/// matmul, random generation, quantization, ...). Variant names mirror the
/// corresponding tensor-API methods.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationIr {
    Exp(UnaryOpIr),
    Log(UnaryOpIr),
    Log1p(UnaryOpIr),
    Erf(UnaryOpIr),
    PowfScalar(ScalarOpIr),
    Sqrt(UnaryOpIr),
    Cos(UnaryOpIr),
    Cosh(UnaryOpIr),
    Sin(UnaryOpIr),
    Sinh(UnaryOpIr),
    Tan(UnaryOpIr),
    Tanh(UnaryOpIr),
    ArcCos(UnaryOpIr),
    ArcCosh(UnaryOpIr),
    ArcSin(UnaryOpIr),
    ArcSinh(UnaryOpIr),
    ArcTan(UnaryOpIr),
    ArcTanh(UnaryOpIr),
    ArcTan2(BinaryOpIr),
    Round(UnaryOpIr),
    Floor(UnaryOpIr),
    Ceil(UnaryOpIr),
    Trunc(UnaryOpIr),
    IntoInt(CastOpIr),
    Matmul(MatmulOpIr),
    Cross(CrossOpIr),
    Random(RandomOpIr),
    Recip(UnaryOpIr),
    IsNan(UnaryOpIr),
    IsInf(UnaryOpIr),
    Quantize(QuantizeOpIr),
    Dequantize(DequantizeOpIr),
    GridSample2d(GridSample2dOpIr),
}
/// Neural-network module operations (embedding, convolution, pooling,
/// interpolation) and their backward counterparts.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationIr {
    Embedding(EmbeddingOpIr),
    EmbeddingBackward(EmbeddingBackwardOpIr),
    Conv1d(Conv1dOpIr),
    Conv2d(Conv2dOpIr),
    Conv3d(Conv3dOpIr),
    // Boxed: the deformable-conv payloads are much larger than the other
    // variants, so boxing keeps the enum itself small.
    DeformableConv2d(Box<DeformConv2dOpIr>),
    DeformableConv2dBackward(Box<DeformConv2dBackwardOpIr>),
    ConvTranspose1d(ConvTranspose1dOpIr),
    ConvTranspose2d(ConvTranspose2dOpIr),
    ConvTranspose3d(ConvTranspose3dOpIr),
    AvgPool1d(AvgPool1dOpIr),
    AvgPool2d(AvgPool2dOpIr),
    AvgPool1dBackward(AvgPool1dBackwardOpIr),
    AvgPool2dBackward(AvgPool2dBackwardOpIr),
    AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr),
    AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr),
    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr),
    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr),
    MaxPool1d(MaxPool1dOpIr),
    MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr),
    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr),
    MaxPool2d(MaxPool2dOpIr),
    MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr),
    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr),
    Interpolate(InterpolateOpIr),
    InterpolateBackward(InterpolateBackwardOpIr),
}
/// Operations available on every tensor kind (shape manipulation, indexing,
/// masking, comparison, creation).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BaseOperationIr {
    Reshape(ShapeOpIr),
    SwapDims(SwapDimsOpIr),
    Permute(PermuteOpIr),
    Flip(FlipOpIr),
    Expand(ShapeOpIr),
    Unfold(UnfoldOpIr),
    Slice(SliceOpIr),
    SliceAssign(SliceAssignOpIr),
    Select(SelectOpIr),
    SelectAssign(SelectAssignOpIr),
    MaskWhere(MaskWhereOpIr),
    MaskFill(MaskFillOpIr),
    Gather(GatherOpIr),
    Scatter(ScatterOpIr),
    Equal(BinaryOpIr),
    EqualElem(ScalarOpIr),
    RepeatDim(RepeatDimOpIr),
    Cat(CatOpIr),
    Cast(CastOpIr),
    Empty(CreationOpIr),
    Ones(CreationOpIr),
    Zeros(CreationOpIr),
}
/// Numeric operations shared by float and int tensors: arithmetic,
/// comparisons, reductions, clamping, and cumulative ops.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum NumericOperationIr {
    Add(BinaryOpIr),
    AddScalar(ScalarOpIr),
    Sub(BinaryOpIr),
    SubScalar(ScalarOpIr),
    Div(BinaryOpIr),
    DivScalar(ScalarOpIr),
    Rem(BinaryOpIr),
    RemScalar(ScalarOpIr),
    Mul(BinaryOpIr),
    MulScalar(ScalarOpIr),
    Abs(UnaryOpIr),
    Full(FullOpIr),
    MeanDim(ReduceDimOpIr),
    Mean(ReduceOpIr),
    Sum(ReduceOpIr),
    SumDim(ReduceDimOpIr),
    Prod(ReduceOpIr),
    ProdDim(ReduceDimOpIr),
    Greater(BinaryOpIr),
    GreaterElem(ScalarOpIr),
    GreaterEqual(BinaryOpIr),
    GreaterEqualElem(ScalarOpIr),
    Lower(BinaryOpIr),
    LowerElem(ScalarOpIr),
    LowerEqual(BinaryOpIr),
    LowerEqualElem(ScalarOpIr),
    ArgMax(ReduceDimOpIr),
    ArgMin(ReduceDimOpIr),
    Max(ReduceOpIr),
    MaxDimWithIndices(ReduceDimWithIndicesOpIr),
    MinDimWithIndices(ReduceDimWithIndicesOpIr),
    Min(ReduceOpIr),
    MaxDim(ReduceDimOpIr),
    MinDim(ReduceDimOpIr),
    MaxAbs(ReduceOpIr),
    MaxAbsDim(ReduceDimOpIr),
    Clamp(ClampOpIr),
    IntRandom(RandomOpIr),
    Powf(BinaryOpIr),
    CumSum(DimOpIr),
    CumProd(DimOpIr),
    CumMin(DimOpIr),
    CumMax(DimOpIr),
}
/// Operations available only on int tensors (bitwise ops, casts, matmul).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationIr {
    IntoFloat(CastOpIr),
    BitwiseAnd(BinaryOpIr),
    BitwiseAndScalar(ScalarOpIr),
    BitwiseOr(BinaryOpIr),
    BitwiseOrScalar(ScalarOpIr),
    BitwiseXor(BinaryOpIr),
    BitwiseXorScalar(ScalarOpIr),
    BitwiseNot(UnaryOpIr),
    BitwiseLeftShift(BinaryOpIr),
    BitwiseLeftShiftScalar(ScalarOpIr),
    BitwiseRightShift(BinaryOpIr),
    BitwiseRightShiftScalar(ScalarOpIr),
    Matmul(MatmulOpIr),
}
/// Operations available only on bool tensors (casts and logic ops).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationIr {
    IntoFloat(CastOpIr),
    IntoInt(CastOpIr),
    Not(UnaryOpIr),
    And(BinaryOpIr),
    Or(BinaryOpIr),
}
/// Swap two dimensions of a tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub dim1: usize,
    pub dim2: usize,
}
/// Permute the axes of a tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axes: Vec<usize>,
}
/// Shape-only transformation (reshape/expand); the target shape is carried
/// by the `out` tensor descriptor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ShapeOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
/// Unfold (sliding window) along one dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct UnfoldOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub dim: usize,
    pub size: usize,
    pub step: usize,
}
/// Flip a tensor along the given axes.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axes: Vec<usize>,
}
/// Fill a tensor with random values drawn from `distribution`.
// NOTE(review): no `Hash` derive here unlike sibling ops — presumably
// `Distribution` does not implement `Hash`; confirm before adding it.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOpIr {
    pub out: TensorIr,
    pub distribution: Distribution,
}
/// Tensor creation op (empty/ones/zeros); shape is carried by `out`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CreationOpIr {
    pub out: TensorIr,
}
/// Fill a tensor with a constant scalar value.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FullOpIr {
    pub out: TensorIr,
    pub value: ScalarIr,
}
/// Initialize a tensor from existing data.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct InitOperationIr {
    pub out: TensorIr,
}
/// Element-wise binary operation: `out = lhs <op> rhs`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
}
/// Matrix multiplication: `out = lhs @ rhs`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MatmulOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
}
/// Cross product along dimension `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CrossOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
    pub dim: usize,
}
/// Element-wise unary operation: `out = f(input)`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
/// Tensor-scalar operation: `out = lhs <op> rhs` with a scalar `rhs`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOpIr {
    pub lhs: TensorIr,
    pub rhs: ScalarIr,
    pub out: TensorIr,
}
/// Full reduction over all elements.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
/// Reduction along a single axis.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceDimOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axis: usize,
}
/// Dtype cast; the target dtype is carried by the `out` descriptor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CastOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
/// Axis-parameterized op that preserves shape (e.g. cumulative ops).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct DimOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axis: usize,
}
/// Gather values along `dim` at the given indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub out: TensorIr,
}
/// Scatter `value` into `tensor` along `dim` at the given indices,
/// combining with the `update` rule.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub value: TensorIr,
    pub update: IndexingUpdateOp,
    pub out: TensorIr,
}
/// Select whole slices along `dim` at the given indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub out: TensorIr,
}
/// Assign `value` to slices of `tensor` along `dim`, combining with the
/// `update` rule.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub value: TensorIr,
    pub update: IndexingUpdateOp,
    pub out: TensorIr,
}
/// Slice a tensor with one range per dimension.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOpIr {
    pub tensor: TensorIr,
    pub ranges: Vec<Slice>,
    pub out: TensorIr,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceAssignOpIr {
pub tensor: TensorIr,
pub ranges: Vec<burn_backend::Slice>,
pub value: TensorIr,
pub out: TensorIr,
}
/// Replace masked elements of `tensor` with elements of `value`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOpIr {
    pub tensor: TensorIr,
    pub mask: TensorIr,
    pub value: TensorIr,
    pub out: TensorIr,
}
/// Replace masked elements of `tensor` with a scalar `value`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOpIr {
    pub tensor: TensorIr,
    pub mask: TensorIr,
    pub value: ScalarIr,
    pub out: TensorIr,
}
/// Clamp elements to the `[min, max]` range.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOpIr {
    pub tensor: TensorIr,
    pub min: ScalarIr,
    pub max: ScalarIr,
    pub out: TensorIr,
}
/// Repeat a tensor `times` times along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatDimOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub times: usize,
    pub out: TensorIr,
}
/// Concatenate tensors along `dim`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOpIr {
    pub tensors: Vec<TensorIr>,
    pub dim: usize,
    pub out: TensorIr,
}
/// Reduction along `dim` that also yields the winning indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}
/// Embedding lookup: gather rows of `weights` at `indices`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingOpIr {
    pub weights: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}
/// Backward pass of the embedding lookup.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardOpIr {
    pub weights: TensorIr,
    pub out_grad: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}
/// 1D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv1dOptionsIr,
    pub out: TensorIr,
}
/// 2D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv2dOptionsIr,
    pub out: TensorIr,
}
/// 2D deformable convolution (learned offsets, optional modulation mask).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dOpIr {
    pub x: TensorIr,
    pub offset: TensorIr,
    pub weight: TensorIr,
    pub mask: Option<TensorIr>,
    pub bias: Option<TensorIr>,
    pub options: DeformableConv2dOptionsIr,
    pub out: TensorIr,
}
/// Backward pass of the 2D deformable convolution; produces one gradient
/// per (present) forward input.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardOpIr {
    pub x: TensorIr,
    pub offset: TensorIr,
    pub weight: TensorIr,
    pub mask: Option<TensorIr>,
    pub bias: Option<TensorIr>,
    pub out_grad: TensorIr,
    pub options: DeformableConv2dOptionsIr,
    pub input_grad: TensorIr,
    pub offset_grad: TensorIr,
    pub weight_grad: TensorIr,
    pub mask_grad: Option<TensorIr>,
    pub bias_grad: Option<TensorIr>,
}
/// 3D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv3dOptionsIr,
    pub out: TensorIr,
}
/// 1D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose1dOptionsIr,
    pub out: TensorIr,
}
/// 2D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose2dOptionsIr,
    pub out: TensorIr,
}
/// 3D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose3dOptionsIr,
    pub out: TensorIr,
}
/// Serializable mirror of `ConvOptions<1>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsIr {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}
/// Serializable mirror of `ConvOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}
/// Serializable mirror of `DeformConvOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub weight_groups: usize,
    pub offset_groups: usize,
}
/// Serializable mirror of `ConvOptions<3>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsIr {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}
/// Serializable mirror of `ConvTransposeOptions<1>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsIr {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub padding_out: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}
/// Serializable mirror of `ConvTransposeOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub padding_out: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}
/// Serializable mirror of `ConvTransposeOptions<3>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOptionsIr {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub padding_out: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}
/// Quantization parameters (per-tensor or per-block scales).
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersIr {
    /// Scale tensor used to (de)quantize values.
    pub scales: TensorIr,
}
/// Quantize a float tensor under the given scheme.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOpIr {
    pub tensor: TensorIr,
    pub qparams: QuantizationParametersIr,
    pub scheme: QuantScheme,
    pub out: TensorIr,
}
/// Dequantize a quantized tensor back to floats.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
impl From<ConvOptions<1>> for Conv1dOptionsIr {
fn from(value: ConvOptions<1>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<ConvOptions<2>> for Conv2dOptionsIr {
fn from(value: ConvOptions<2>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<ConvOptions<3>> for Conv3dOptionsIr {
fn from(value: ConvOptions<3>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<DeformConvOptions<2>> for DeformableConv2dOptionsIr {
fn from(value: DeformConvOptions<2>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
weight_groups: value.weight_groups,
offset_groups: value.offset_groups,
}
}
}
impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsIr {
fn from(value: ConvTransposeOptions<1>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
padding_out: value.padding_out,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsIr {
fn from(value: ConvTransposeOptions<2>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
padding_out: value.padding_out,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsIr {
fn from(value: ConvTransposeOptions<3>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
padding_out: value.padding_out,
dilation: value.dilation,
groups: value.groups,
}
}
}
impl From<Conv1dOptionsIr> for ConvOptions<1> {
fn from(val: Conv1dOptionsIr) -> Self {
ConvOptions {
stride: val.stride,
padding: val.padding,
dilation: val.dilation,
groups: val.groups,
}
}
}
impl From<Conv2dOptionsIr> for ConvOptions<2> {
fn from(val: Conv2dOptionsIr) -> Self {
ConvOptions {
stride: val.stride,
padding: val.padding,
dilation: val.dilation,
groups: val.groups,
}
}
}
impl From<Conv3dOptionsIr> for ConvOptions<3> {
fn from(val: Conv3dOptionsIr) -> Self {
ConvOptions {
stride: val.stride,
padding: val.padding,
dilation: val.dilation,
groups: val.groups,
}
}
}
impl From<DeformableConv2dOptionsIr> for DeformConvOptions<2> {
fn from(value: DeformableConv2dOptionsIr) -> Self {
DeformConvOptions {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
weight_groups: value.weight_groups,
offset_groups: value.offset_groups,
}
}
}
impl From<ConvTranspose1dOptionsIr> for ConvTransposeOptions<1> {
fn from(val: ConvTranspose1dOptionsIr) -> Self {
ConvTransposeOptions {
stride: val.stride,
padding: val.padding,
padding_out: val.padding_out,
dilation: val.dilation,
groups: val.groups,
}
}
}
impl From<ConvTranspose2dOptionsIr> for ConvTransposeOptions<2> {
fn from(val: ConvTranspose2dOptionsIr) -> Self {
ConvTransposeOptions {
stride: val.stride,
padding: val.padding,
padding_out: val.padding_out,
dilation: val.dilation,
groups: val.groups,
}
}
}
impl From<ConvTranspose3dOptionsIr> for ConvTransposeOptions<3> {
fn from(val: ConvTranspose3dOptionsIr) -> Self {
ConvTransposeOptions {
stride: val.stride,
padding: val.padding,
padding_out: val.padding_out,
dilation: val.dilation,
groups: val.groups,
}
}
}
/// 1D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
/// 2D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
/// Backward pass of 1D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
/// Backward pass of 2D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
/// 1D adaptive average pooling to a fixed output length.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dOpIr {
    pub x: TensorIr,
    pub output_size: usize,
    pub out: TensorIr,
}
/// 2D adaptive average pooling to a fixed output size.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dOpIr {
    pub x: TensorIr,
    pub output_size: [usize; 2],
    pub out: TensorIr,
}
/// Backward pass of 1D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub out: TensorIr,
}
/// Backward pass of 2D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub out: TensorIr,
}
/// 1D max pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
/// 1D max pooling that also returns the argmax indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}
/// Backward pass of 1D max pooling, routed through the saved indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub indices: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
}
/// 2D max pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
}
/// 2D max pooling that also returns the argmax indices.
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MaxPool2dWithIndicesOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}
/// Backward pass of 2D max pooling, routed through the saved indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub indices: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
}
/// Serializable mirror of `InterpolateMode`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeIr {
    Nearest,
    Bilinear,
    Bicubic,
}
/// Serializable mirror of `InterpolateOptions`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsIr {
    pub mode: InterpolateModeIr,
}
/// Resize a tensor to `output_size` with the configured interpolation mode.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOpIr {
    pub x: TensorIr,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsIr,
    pub out: TensorIr,
}
impl From<InterpolateModeIr> for InterpolateMode {
fn from(val: InterpolateModeIr) -> Self {
match val {
InterpolateModeIr::Nearest => Self::Nearest,
InterpolateModeIr::Bilinear => Self::Bilinear,
InterpolateModeIr::Bicubic => Self::Bicubic,
}
}
}
impl From<InterpolateOptionsIr> for InterpolateOptions {
fn from(val: InterpolateOptionsIr) -> Self {
Self {
mode: val.mode.into(),
}
}
}
impl From<InterpolateMode> for InterpolateModeIr {
fn from(val: InterpolateMode) -> Self {
match val {
InterpolateMode::Nearest => Self::Nearest,
InterpolateMode::Bilinear => Self::Bilinear,
InterpolateMode::Bicubic => Self::Bicubic,
}
}
}
impl From<InterpolateOptions> for InterpolateOptionsIr {
fn from(val: InterpolateOptions) -> Self {
Self {
mode: val.mode.into(),
}
}
}
/// Backward pass of interpolation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsIr,
    pub out: TensorIr,
}
/// Serializable mirror of `GridSamplePaddingMode`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum GridSamplePaddingModeIr {
    Zeros,
    Border,
    Reflection,
}
/// Serializable mirror of `GridSampleOptions`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GridSampleOptionsIr {
    pub mode: InterpolateModeIr,
    pub padding_mode: GridSamplePaddingModeIr,
    pub align_corners: bool,
}
/// Sample `tensor` at the coordinates given by `grid`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GridSample2dOpIr {
    pub tensor: TensorIr,
    pub grid: TensorIr,
    pub options: GridSampleOptionsIr,
    pub out: TensorIr,
}
impl From<GridSamplePaddingModeIr> for GridSamplePaddingMode {
fn from(val: GridSamplePaddingModeIr) -> Self {
match val {
GridSamplePaddingModeIr::Zeros => Self::Zeros,
GridSamplePaddingModeIr::Border => Self::Border,
GridSamplePaddingModeIr::Reflection => Self::Reflection,
}
}
}
impl From<GridSamplePaddingMode> for GridSamplePaddingModeIr {
fn from(val: GridSamplePaddingMode) -> Self {
match val {
GridSamplePaddingMode::Zeros => Self::Zeros,
GridSamplePaddingMode::Border => Self::Border,
GridSamplePaddingMode::Reflection => Self::Reflection,
}
}
}
impl From<GridSampleOptionsIr> for GridSampleOptions {
fn from(val: GridSampleOptionsIr) -> Self {
Self {
mode: val.mode.into(),
padding_mode: val.padding_mode.into(),
align_corners: val.align_corners,
}
}
}
impl From<GridSampleOptions> for GridSampleOptionsIr {
fn from(val: GridSampleOptions) -> Self {
Self {
mode: val.mode.into(),
padding_mode: val.padding_mode.into(),
align_corners: val.align_corners,
}
}
}
impl OperationIr {
    /// Iterator over the tensors this operation reads, delegating to the
    /// inner operation kind. `Drop` reports its single tensor as an input.
    pub fn inputs(&self) -> impl Iterator<Item = &TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.inputs(),
            OperationIr::BaseInt(repr) => repr.inputs(),
            OperationIr::BaseBool(repr) => repr.inputs(),
            OperationIr::NumericFloat(_dtype, repr) => repr.inputs(),
            OperationIr::NumericInt(_dtype, repr) => repr.inputs(),
            OperationIr::Bool(repr) => repr.inputs(),
            OperationIr::Int(repr) => repr.inputs(),
            OperationIr::Float(_dtype, repr) => repr.inputs(),
            OperationIr::Module(repr) => repr.inputs(),
            OperationIr::Init(repr) => repr.inputs(),
            OperationIr::Custom(repr) => repr.inputs(),
            OperationIr::Drop(repr) => Box::new([repr].into_iter()),
        }
    }
    /// Iterator over the tensors this operation writes. `Drop` produces
    /// nothing.
    pub fn outputs(&self) -> impl Iterator<Item = &TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.outputs(),
            OperationIr::BaseInt(repr) => repr.outputs(),
            OperationIr::BaseBool(repr) => repr.outputs(),
            OperationIr::NumericFloat(_dtype, repr) => repr.outputs(),
            OperationIr::NumericInt(_dtype, repr) => repr.outputs(),
            OperationIr::Bool(repr) => repr.outputs(),
            OperationIr::Int(repr) => repr.outputs(),
            OperationIr::Float(_dtype, repr) => repr.outputs(),
            OperationIr::Module(repr) => repr.outputs(),
            OperationIr::Init(repr) => repr.outputs(),
            OperationIr::Custom(repr) => repr.outputs(),
            OperationIr::Drop(_repr) => Box::new([].into_iter()),
        }
    }
    /// All tensors touched by the operation: inputs followed by outputs.
    pub fn nodes(&self) -> Vec<&TensorIr> {
        self.inputs().chain(self.outputs()).collect()
    }
    /// Mark the tensors listed in `nodes` as read-only within this
    /// operation, returning the tensor IRs that were updated.
    pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseInt(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseBool(repr) => repr.mark_read_only(nodes),
            OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Bool(repr) => repr.mark_read_only(nodes),
            OperationIr::Int(repr) => repr.mark_read_only(nodes),
            OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Module(repr) => repr.mark_read_only(nodes),
            // Init only creates its output; it reads nothing.
            OperationIr::Init(_) => Vec::new(),
            OperationIr::Drop(repr) => {
                let mut output = Vec::new();
                repr.mark_read_only(nodes, &mut output);
                output
            }
            OperationIr::Custom(repr) => {
                let mut output = Vec::new();
                // NOTE(review): only the inputs are marked here; outputs are
                // presumably always writable for custom ops — confirm.
                for input in repr.inputs.iter_mut() {
                    input.mark_read_only(nodes, &mut output);
                }
                output
            }
        }
    }
}
impl BaseOperationIr {
fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
BaseOperationIr::Reshape(repr) => Box::new([&repr.input].into_iter()),
BaseOperationIr::SwapDims(repr) => Box::new([&repr.input].into_iter()),
BaseOperationIr::Permute(repr) => Box::new([&repr.input].into_iter()),
BaseOperationIr::Expand(repr) => Box::new([&repr.input].into_iter()),
BaseOperationIr::Flip(repr) => Box::new([&repr.input].into_iter()),
BaseOperationIr::Slice(repr) => Box::new([&repr.tensor].into_iter()),
BaseOperationIr::SliceAssign(repr) => Box::new([&repr.tensor, &repr.value].into_iter()),
BaseOperationIr::Gather(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
BaseOperationIr::Scatter(repr) => {
Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
}
BaseOperationIr::Select(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
BaseOperationIr::SelectAssign(repr) => {
Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
}
BaseOperationIr::MaskWhere(repr) => {
Box::new([&repr.tensor, &repr.mask, &repr.value].into_iter())
}
BaseOperationIr::MaskFill(repr) => Box::new([&repr.tensor, &repr.mask].into_iter()),
BaseOperationIr::Equal(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
BaseOperationIr::EqualElem(repr) => Box::new([&repr.lhs].into_iter()),
BaseOperationIr::RepeatDim(repr) => Box::new([&repr.tensor].into_iter()),
BaseOperationIr::Cat(repr) => Box::new(repr.tensors.iter()),
BaseOperationIr::Cast(repr) => Box::new([&repr.input].into_iter()),
BaseOperationIr::Unfold(repr) => Box::new([&repr.input].into_iter()),
BaseOperationIr::Empty(_repr) => Box::new([].into_iter()),
BaseOperationIr::Ones(_repr) => Box::new([].into_iter()),
BaseOperationIr::Zeros(_repr) => Box::new([].into_iter()),
}
}
fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
BaseOperationIr::Reshape(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::SwapDims(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Permute(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Expand(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Flip(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Slice(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::SliceAssign(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Gather(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Scatter(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Select(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::SelectAssign(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::MaskWhere(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::MaskFill(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Equal(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::EqualElem(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::RepeatDim(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Cat(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Cast(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Unfold(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Empty(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Ones(repr) => Box::new([&repr.out].into_iter()),
BaseOperationIr::Zeros(repr) => Box::new([&repr.out].into_iter()),
}
}
/// Marks, in place, the tensor *inputs* of this operation whose ids appear
/// in `nodes` as read-only.
///
/// The returned vec collects the descriptors affected by the update
/// (presumably those whose status actually changed — the exact semantics
/// are delegated to `TensorIr::mark_read_only`; confirm there). Output
/// tensors are never visited, and creation ops (`Empty`, `Zeros`, `Ones`)
/// have no tensor inputs, so they yield an empty vec.
fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
let mut output = Vec::new();
match self {
BaseOperationIr::Reshape(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::SwapDims(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Permute(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Expand(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Flip(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Slice(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
}
BaseOperationIr::SliceAssign(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
repr.value.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Gather(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
repr.indices.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Scatter(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
repr.indices.mark_read_only(nodes, &mut output);
repr.value.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Select(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
repr.indices.mark_read_only(nodes, &mut output);
}
BaseOperationIr::SelectAssign(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
repr.indices.mark_read_only(nodes, &mut output);
repr.value.mark_read_only(nodes, &mut output);
}
BaseOperationIr::MaskWhere(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
repr.mask.mark_read_only(nodes, &mut output);
repr.value.mark_read_only(nodes, &mut output);
}
BaseOperationIr::MaskFill(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
repr.mask.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Equal(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
BaseOperationIr::EqualElem(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
BaseOperationIr::RepeatDim(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Cat(repr) => {
// Cat is the only base op with a variable number of tensor inputs.
for t in repr.tensors.iter_mut() {
t.mark_read_only(nodes, &mut output);
}
}
BaseOperationIr::Cast(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Unfold(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Empty(_) => {}
BaseOperationIr::Zeros(_) => {}
BaseOperationIr::Ones(_) => {}
};
output
}
}
impl NumericOperationIr {
fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
NumericOperationIr::Add(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::AddScalar(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::Sub(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::SubScalar(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::Mul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::MulScalar(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::Div(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::DivScalar(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::Rem(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::RemScalar(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::GreaterElem(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::LowerElem(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
NumericOperationIr::Greater(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::Lower(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::LowerEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::ArgMax(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::ArgMin(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::Clamp(repr) => Box::new([&repr.tensor].into_iter()),
NumericOperationIr::Abs(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::Full(_repr) => Box::new([].into_iter()),
NumericOperationIr::MeanDim(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::Mean(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::Sum(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::SumDim(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::Prod(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::ProdDim(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::Max(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::MaxDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
NumericOperationIr::MinDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
NumericOperationIr::Min(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::MaxDim(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::MinDim(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::MaxAbs(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.input].into_iter()),
NumericOperationIr::IntRandom(_repr) => Box::new([].into_iter()),
NumericOperationIr::Powf(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),
}
}
fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
NumericOperationIr::Add(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::AddScalar(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Sub(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::SubScalar(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Mul(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::MulScalar(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Div(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::DivScalar(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Rem(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::RemScalar(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::GreaterElem(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::LowerElem(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Greater(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Lower(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::LowerEqual(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::ArgMax(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::ArgMin(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Clamp(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Abs(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Full(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::MeanDim(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Mean(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Sum(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::SumDim(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Prod(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::ProdDim(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Max(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::MaxDimWithIndices(repr) => {
Box::new([&repr.out, &repr.out_indices].into_iter())
}
NumericOperationIr::MinDimWithIndices(repr) => {
Box::new([&repr.out, &repr.out_indices].into_iter())
}
NumericOperationIr::Min(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::MaxDim(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::MinDim(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::MaxAbs(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::IntRandom(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::Powf(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),
NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),
}
}
fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
let mut output = Vec::new();
match self {
NumericOperationIr::Add(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::AddScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Sub(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::SubScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Mul(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::MulScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Div(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::DivScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Rem(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::RemScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::GreaterElem(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::GreaterEqualElem(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::LowerElem(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::LowerEqualElem(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Greater(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::GreaterEqual(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Lower(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::LowerEqual(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::ArgMax(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::ArgMin(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Clamp(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Abs(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Full(_) => {}
NumericOperationIr::MeanDim(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Mean(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Sum(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::SumDim(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Prod(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::ProdDim(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Max(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::MaxDimWithIndices(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
}
NumericOperationIr::MinDimWithIndices(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
}
NumericOperationIr::Min(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::MaxDim(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::MinDim(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::MaxAbs(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::MaxAbsDim(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::IntRandom(_) => {}
NumericOperationIr::Powf(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
NumericOperationIr::CumSum(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::CumProd(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::CumMin(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
NumericOperationIr::CumMax(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
};
output
}
}
impl FloatOperationIr {
/// Returns an iterator over the tensor descriptors *read* by this float
/// operation. `Random` reads no tensors; `Quantize` additionally reads
/// the `scales` tensor from its quantization parameters; `GridSample2d`
/// reads both the sampled tensor and the grid.
fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
FloatOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
FloatOperationIr::Cross(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
FloatOperationIr::Random(_repr) => Box::new([].into_iter()),
FloatOperationIr::Exp(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Log(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Log1p(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Erf(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Recip(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::PowfScalar(repr) => Box::new([&repr.lhs].into_iter()),
FloatOperationIr::Sqrt(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Cos(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Sin(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Tanh(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Round(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Floor(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Ceil(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Trunc(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Quantize(repr) => {
Box::new([&repr.tensor, &repr.qparams.scales].into_iter())
}
FloatOperationIr::Dequantize(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::IsNan(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::IsInf(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::GridSample2d(repr) => {
Box::new([&repr.tensor, &repr.grid].into_iter())
}
FloatOperationIr::Tan(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Cosh(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::Sinh(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::ArcCos(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::ArcCosh(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::ArcSin(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::ArcSinh(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::ArcTan(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::ArcTanh(repr) => Box::new([&repr.input].into_iter()),
FloatOperationIr::ArcTan2(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
}
}
/// Returns an iterator over the tensor descriptors *written* by this
/// float operation. Every variant produces exactly one output tensor.
fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
FloatOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Cross(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Random(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Exp(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Log(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Log1p(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Erf(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Recip(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::PowfScalar(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Sqrt(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Cos(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Sin(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Tanh(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Round(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Floor(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Ceil(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Trunc(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Quantize(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Dequantize(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::IsNan(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::IsInf(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::GridSample2d(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Tan(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Cosh(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::Sinh(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::ArcCos(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::ArcCosh(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::ArcSin(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::ArcSinh(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::ArcTan(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::ArcTanh(repr) => Box::new([&repr.out].into_iter()),
FloatOperationIr::ArcTan2(repr) => Box::new([&repr.out].into_iter()),
}
}
/// Marks, in place, the tensor *inputs* of this operation whose ids
/// appear in `nodes` as read-only, collecting the affected descriptors
/// into the returned vec (delegated to `TensorIr::mark_read_only`).
/// `Random` has no tensor inputs.
fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
let mut output = Vec::new();
match self {
FloatOperationIr::Matmul(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Cross(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Random(_) => {}
FloatOperationIr::Exp(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Log(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Log1p(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Erf(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Recip(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::PowfScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Sqrt(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Cos(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Sin(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Tanh(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Round(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Floor(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Ceil(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Trunc(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Quantize(repr) => {
// Both the value tensor and the scales tensor are inputs.
repr.tensor.mark_read_only(nodes, &mut output);
repr.qparams.scales.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Dequantize(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::IntoInt(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::IsNan(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::IsInf(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
FloatOperationIr::GridSample2d(repr) => {
repr.tensor.mark_read_only(nodes, &mut output);
repr.grid.mark_read_only(nodes, &mut output);
}
FloatOperationIr::Tan(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::Cosh(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::Sinh(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::ArcCos(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::ArcCosh(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::ArcSin(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::ArcSinh(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::ArcTan(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::ArcTanh(repr) => repr.input.mark_read_only(nodes, &mut output),
FloatOperationIr::ArcTan2(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
};
output
}
}
impl IntOperationIr {
fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
IntOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
IntOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.lhs].into_iter()),
IntOperationIr::BitwiseOr(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.lhs].into_iter()),
IntOperationIr::BitwiseXor(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.lhs].into_iter()),
IntOperationIr::BitwiseNot(repr) => Box::new([&repr.input].into_iter()),
IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
}
}
fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
IntOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseOr(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseXor(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseNot(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.out].into_iter()),
IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.out].into_iter()),
}
}
fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
let mut output = Vec::new();
match self {
IntOperationIr::Matmul(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::IntoFloat(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseAnd(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseAndScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseOr(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseOrScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseXor(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseXorScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseNot(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseLeftShift(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseLeftShiftScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseRightShift(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
IntOperationIr::BitwiseRightShiftScalar(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
}
};
output
}
}
impl BoolOperationIr {
fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
BoolOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
BoolOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
BoolOperationIr::Not(repr) => Box::new([&repr.input].into_iter()),
BoolOperationIr::And(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
BoolOperationIr::Or(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
}
}
fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
BoolOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
BoolOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
BoolOperationIr::Not(repr) => Box::new([&repr.out].into_iter()),
BoolOperationIr::And(repr) => Box::new([&repr.out].into_iter()),
BoolOperationIr::Or(repr) => Box::new([&repr.out].into_iter()),
}
}
fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
let mut output = Vec::new();
match self {
BoolOperationIr::IntoFloat(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BoolOperationIr::IntoInt(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BoolOperationIr::Not(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BoolOperationIr::And(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
BoolOperationIr::Or(repr) => {
repr.lhs.mark_read_only(nodes, &mut output);
repr.rhs.mark_read_only(nodes, &mut output);
}
};
output
}
}
impl ModuleOperationIr {
/// Returns an iterator over the tensor descriptors *read* by this module
/// operation. Optional tensors (conv biases, deformable-conv masks) are
/// included only when present, which is why the conv variants branch on
/// `Option` contents before building the iterator.
fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
ModuleOperationIr::Embedding(repr) => {
Box::new([&repr.weights, &repr.indices].into_iter())
}
ModuleOperationIr::EmbeddingBackward(repr) => {
Box::new([&repr.weights, &repr.out_grad, &repr.indices].into_iter())
}
ModuleOperationIr::Conv1d(repr) => {
if let Some(bias) = &repr.bias {
Box::new([&repr.x, &repr.weight, bias].into_iter())
} else {
Box::new([&repr.x, &repr.weight].into_iter())
}
}
ModuleOperationIr::Conv2d(repr) => {
if let Some(bias) = &repr.bias {
Box::new([&repr.x, &repr.weight, bias].into_iter())
} else {
Box::new([&repr.x, &repr.weight].into_iter())
}
}
ModuleOperationIr::Conv3d(repr) => {
if let Some(bias) = &repr.bias {
Box::new([&repr.x, &repr.weight, bias].into_iter())
} else {
Box::new([&repr.x, &repr.weight].into_iter())
}
}
// Deformable conv has two independent optional inputs (mask, bias),
// so all four presence combinations are enumerated.
ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) {
(Some(mask), Some(bias)) => {
Box::new([&repr.x, &repr.offset, &repr.weight, mask, bias].into_iter())
}
(Some(mask), None) => {
Box::new([&repr.x, &repr.offset, &repr.weight, mask].into_iter())
}
(None, Some(bias)) => {
Box::new([&repr.x, &repr.offset, &repr.weight, bias].into_iter())
}
(None, None) => Box::new([&repr.x, &repr.offset, &repr.weight].into_iter()),
},
ModuleOperationIr::DeformableConv2dBackward(repr) => match (&repr.mask, &repr.bias) {
(Some(mask), Some(bias)) => Box::new(
[
&repr.x,
&repr.offset,
&repr.weight,
&repr.out_grad,
mask,
bias,
]
.into_iter(),
),
(Some(mask), None) => Box::new(
[&repr.x, &repr.offset, &repr.weight, &repr.out_grad, mask].into_iter(),
),
(None, Some(bias)) => Box::new(
[&repr.x, &repr.offset, &repr.weight, &repr.out_grad, bias].into_iter(),
),
(None, None) => {
Box::new([&repr.x, &repr.offset, &repr.weight, &repr.out_grad].into_iter())
}
},
ModuleOperationIr::ConvTranspose1d(repr) => {
if let Some(bias) = &repr.bias {
Box::new([&repr.x, &repr.weight, bias].into_iter())
} else {
Box::new([&repr.x, &repr.weight].into_iter())
}
}
ModuleOperationIr::ConvTranspose2d(repr) => {
if let Some(bias) = &repr.bias {
Box::new([&repr.x, &repr.weight, bias].into_iter())
} else {
Box::new([&repr.x, &repr.weight].into_iter())
}
}
ModuleOperationIr::ConvTranspose3d(repr) => {
if let Some(bias) = &repr.bias {
Box::new([&repr.x, &repr.weight, bias].into_iter())
} else {
Box::new([&repr.x, &repr.weight].into_iter())
}
}
ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::AvgPool1dBackward(repr) => {
Box::new([&repr.x, &repr.grad].into_iter())
}
ModuleOperationIr::AvgPool2dBackward(repr) => {
Box::new([&repr.x, &repr.grad].into_iter())
}
ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
Box::new([&repr.x, &repr.grad].into_iter())
}
ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
Box::new([&repr.x, &repr.grad].into_iter())
}
ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::MaxPool1dWithIndices(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
}
ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::MaxPool2dWithIndices(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
}
ModuleOperationIr::Interpolate(repr) => Box::new([&repr.x].into_iter()),
ModuleOperationIr::InterpolateBackward(repr) => {
Box::new([&repr.x, &repr.grad].into_iter())
}
}
}
/// Returns an iterator over the tensor descriptors *written* by this module
/// operation. Most variants write a single `out`; `MaxPool*WithIndices`
/// also write `out_indices`, and `DeformableConv2dBackward` writes a
/// gradient per (present) input.
fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
match self {
ModuleOperationIr::Embedding(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::EmbeddingBackward(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::Conv1d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::Conv2d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::Conv3d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::DeformableConv2d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::DeformableConv2dBackward(repr) => {
// Mask/bias gradients exist only when the forward op had the
// corresponding optional input, hence the four combinations.
match (&repr.mask_grad, &repr.bias_grad) {
(Some(mask_grad), Some(bias_grad)) => Box::new(
[
&repr.input_grad,
&repr.offset_grad,
&repr.weight_grad,
mask_grad,
bias_grad,
]
.into_iter(),
),
(Some(mask_grad), None) => Box::new(
[
&repr.input_grad,
&repr.offset_grad,
&repr.weight_grad,
mask_grad,
]
.into_iter(),
),
(None, Some(bias_grad)) => Box::new(
[
&repr.input_grad,
&repr.offset_grad,
&repr.weight_grad,
bias_grad,
]
.into_iter(),
),
(None, None) => Box::new(
[&repr.input_grad, &repr.offset_grad, &repr.weight_grad].into_iter(),
),
}
}
ModuleOperationIr::ConvTranspose1d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::ConvTranspose2d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::ConvTranspose3d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::AvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::AvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::MaxPool1dWithIndices(repr) => {
Box::new([&repr.out, &repr.out_indices].into_iter())
}
ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
Box::new([&repr.out].into_iter())
}
ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::MaxPool2dWithIndices(repr) => {
Box::new([&repr.out, &repr.out_indices].into_iter())
}
ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
Box::new([&repr.out].into_iter())
}
ModuleOperationIr::Interpolate(repr) => Box::new([&repr.out].into_iter()),
ModuleOperationIr::InterpolateBackward(repr) => Box::new([&repr.out].into_iter()),
}
}
/// Marks every tensor read by this module operation that appears in `nodes`
/// as read-only, returning the original (pre-transition) descriptors.
///
/// Only tensors currently in the `TensorStatus::ReadWrite` state are
/// affected; the per-tensor transition rule lives in
/// [`TensorIr::mark_read_only`]. Output tensors (e.g. pooling results or
/// max-pool index tensors) are never touched — only operation inputs.
fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
    let mut output = Vec::new();
    match self {
        ModuleOperationIr::Embedding(repr) => {
            repr.weights.mark_read_only(nodes, &mut output);
            repr.indices.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::EmbeddingBackward(repr) => {
            repr.weights.mark_read_only(nodes, &mut output);
            repr.out_grad.mark_read_only(nodes, &mut output);
            repr.indices.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::Conv1d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.weight.mark_read_only(nodes, &mut output);
            if let Some(bias) = &mut repr.bias {
                bias.mark_read_only(nodes, &mut output);
            }
        }
        ModuleOperationIr::Conv2d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.weight.mark_read_only(nodes, &mut output);
            if let Some(bias) = &mut repr.bias {
                bias.mark_read_only(nodes, &mut output);
            }
        }
        ModuleOperationIr::Conv3d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.weight.mark_read_only(nodes, &mut output);
            if let Some(bias) = &mut repr.bias {
                bias.mark_read_only(nodes, &mut output);
            }
        }
        ModuleOperationIr::DeformableConv2d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.weight.mark_read_only(nodes, &mut output);
            repr.offset.mark_read_only(nodes, &mut output);
            // The optional mask and bias are independent; handle them with
            // two `if let`s, consistent with `DeformableConv2dBackward`
            // below (previously a verbose 4-way tuple match).
            if let Some(mask) = repr.mask.as_mut() {
                mask.mark_read_only(nodes, &mut output);
            }
            if let Some(bias) = repr.bias.as_mut() {
                bias.mark_read_only(nodes, &mut output);
            }
        }
        ModuleOperationIr::DeformableConv2dBackward(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.weight.mark_read_only(nodes, &mut output);
            repr.offset.mark_read_only(nodes, &mut output);
            repr.out_grad.mark_read_only(nodes, &mut output);
            if let Some(mask) = repr.mask.as_mut() {
                mask.mark_read_only(nodes, &mut output);
            }
            if let Some(bias) = repr.bias.as_mut() {
                bias.mark_read_only(nodes, &mut output);
            }
        }
        ModuleOperationIr::ConvTranspose1d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.weight.mark_read_only(nodes, &mut output);
            if let Some(bias) = &mut repr.bias {
                bias.mark_read_only(nodes, &mut output);
            }
        }
        ModuleOperationIr::ConvTranspose2d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.weight.mark_read_only(nodes, &mut output);
            if let Some(bias) = &mut repr.bias {
                bias.mark_read_only(nodes, &mut output);
            }
        }
        ModuleOperationIr::ConvTranspose3d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.weight.mark_read_only(nodes, &mut output);
            if let Some(bias) = &mut repr.bias {
                bias.mark_read_only(nodes, &mut output);
            }
        }
        ModuleOperationIr::AvgPool1d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::AvgPool2d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::AvgPool1dBackward(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.grad.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::AvgPool2dBackward(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.grad.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.grad.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.grad.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::MaxPool1d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::MaxPool1dWithIndices(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.grad.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::MaxPool2d(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::MaxPool2dWithIndices(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.grad.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::Interpolate(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
        }
        ModuleOperationIr::InterpolateBackward(repr) => {
            repr.x.mark_read_only(nodes, &mut output);
            repr.grad.mark_read_only(nodes, &mut output);
        }
    };
    output
}
}
impl InitOperationIr {
    /// An init operation reads no existing tensors.
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        Box::new(core::iter::empty())
    }

    /// The single tensor produced by the initialization.
    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        Box::new(core::iter::once(&self.out))
    }
}
impl TensorIr {
    /// Transitions this tensor to `TensorStatus::ReadOnly` when it is
    /// currently read-write and its id is listed in `nodes`; a clone of the
    /// pre-transition descriptor is pushed onto `output`. Otherwise a no-op.
    fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec<TensorIr>) {
        // Guard: only read-write tensors are eligible for the transition.
        if self.status != TensorStatus::ReadWrite {
            return;
        }
        // Guard: only tensors explicitly selected by the caller.
        if !nodes.contains(&self.id) {
            return;
        }
        output.push(self.clone());
        self.status = TensorStatus::ReadOnly;
    }
}
impl core::hash::Hash for RandomOpIr {
    /// Manual `Hash`: hashes the output descriptor plus a one-byte tag
    /// identifying the distribution variant. The variant payloads are
    /// deliberately skipped (presumably because they hold float parameters,
    /// which do not implement `Hash` — TODO confirm).
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        self.out.hash(state);
        let tag: u8 = match self.distribution {
            Distribution::Default => 1,
            Distribution::Bernoulli(_) => 2,
            Distribution::Uniform(_, _) => 3,
            Distribution::Normal(_, _) => 4,
        };
        tag.hash(state);
    }
}
/// Access the output(s) of an executed operation.
pub trait OperationOutput<O> {
    /// Returns the single output of the operation.
    ///
    /// # Panics
    ///
    /// Panics if the operation does not have exactly one output.
    fn output(self) -> O;

    /// Returns the `N` outputs of the operation as a fixed-size array.
    ///
    /// # Panics
    ///
    /// Panics if the operation does not have exactly `N` outputs.
    fn outputs<const N: usize>(self) -> [O; N];
}

impl<O: core::fmt::Debug> OperationOutput<O> for Vec<O> {
    fn output(self) -> O {
        // Delegates to `outputs` with N = 1 via array destructuring.
        let [tensor] = self.outputs();
        tensor
    }

    fn outputs<const N: usize>(self) -> [O; N] {
        // Capture the length before `try_into` consumes the vector, so the
        // panic message can state expected vs. actual instead of dumping
        // the whole vector's Debug representation.
        let actual = self.len();
        self.try_into().unwrap_or_else(|_| {
            panic!(
                "Wrong number of outputs: expected {}, got {}; check your implementation",
                N, actual
            )
        })
    }
}