
burn_ir/operation.rs

use burn_backend::ops::AttentionModuleOptions;
use burn_backend::tensor::IndexingUpdateOp;
use core::hash::Hash;
use serde::{Deserialize, Serialize};

use alloc::borrow::ToOwned;
use alloc::boxed::Box;
use alloc::{string::String, vec::Vec};

use burn_backend::{
    DType, Distribution, Slice,
    ops::{
        ConvOptions, ConvTransposeOptions, DeformConvOptions, GridSampleOptions,
        GridSamplePaddingMode, InterpolateMode, InterpolateOptions,
    },
    quantization::QuantScheme,
};

use crate::{ScalarIr, TensorId, TensorIr, TensorStatus};

/// Custom operation in a fusion stream, declaring its inputs and outputs.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CustomOpIr {
    /// Unique identifier of the operation.
    pub id: String,
    /// Input tensors used in the custom operation.
    pub inputs: Vec<TensorIr>,
    /// Output tensors used in the custom operation.
    pub outputs: Vec<TensorIr>,
}

impl CustomOpIr {
    /// Create a new custom operation intermediate representation.
    pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self {
        Self {
            id: id.to_owned(),
            inputs: inputs.to_vec(),
            outputs: outputs.to_vec(),
        }
    }

    /// Cast the intermediate representation and get the input and output tensors.
    pub fn as_fixed<const N_IN: usize, const N_OUT: usize>(
        &self,
    ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) {
        (
            self.inputs.as_slice().try_into().unwrap_or_else(|_| {
                panic!(
                    "Wrong number of inputs (expected {}, got {}), check your implementation",
                    N_IN,
                    self.inputs.len()
                )
            }),
            self.outputs.as_slice().try_into().unwrap_or_else(|_| {
                panic!(
                    "Wrong number of outputs (expected {}, got {}), check your implementation",
                    N_OUT,
                    self.outputs.len()
                )
            }),
        )
    }

    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        Box::new(self.inputs.iter())
    }

    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        Box::new(self.outputs.iter())
    }
}
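
// Illustrative usage sketch (not part of the original file): how a fusion
// backend might construct a `CustomOpIr` and recover fixed-arity views of its
// tensors. The operation id and the arities `<2, 1>` are hypothetical.
//
//     fn register_fused_op(inputs: &[TensorIr], outputs: &[TensorIr]) -> CustomOpIr {
//         let op = CustomOpIr::new("my_backend::fused_gelu", inputs, outputs);
//         // Panics unless there are exactly 2 inputs and 1 output.
//         let (_ins, _outs) = op.as_fixed::<2, 1>();
//         op
//     }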

/// Describes all possible tensor operations.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(clippy::large_enum_variant)]
pub enum OperationIr {
    /// Basic operation on a float tensor.
    BaseFloat(BaseOperationIr),
    /// Basic operation on an int tensor.
    BaseInt(BaseOperationIr),
    /// Basic operation on a bool tensor.
    BaseBool(BaseOperationIr),
    /// Numeric operation on a float tensor.
    NumericFloat(DType, NumericOperationIr),
    /// Numeric operation on an int tensor.
    NumericInt(DType, NumericOperationIr),
    /// Operation specific to a bool tensor.
    Bool(BoolOperationIr),
    /// Operation specific to an int tensor.
    Int(IntOperationIr),
    /// Operation specific to a float tensor.
    Float(DType, FloatOperationIr),
    /// Module operation.
    Module(ModuleOperationIr),
    /// Initialize operation.
    Init(InitOperationIr),
    /// A custom operation.
    Custom(CustomOpIr),
    /// A tensor is dropped.
    Drop(TensorIr),
    #[cfg(feature = "distributed")]
    /// Operation specific to a distributed tensor.
    Distributed(DistributedOperationIr),
}
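
// Illustrative sketch (hypothetical helper, not from the original file): code
// consuming a fusion stream typically dispatches on the operation category, e.g.:
//
//     fn category(op: &OperationIr) -> &'static str {
//         match op {
//             OperationIr::BaseFloat(_) | OperationIr::BaseInt(_) | OperationIr::BaseBool(_) => "base",
//             OperationIr::NumericFloat(..) | OperationIr::NumericInt(..) => "numeric",
//             OperationIr::Module(_) => "module",
//             OperationIr::Custom(_) => "custom",
//             _ => "other",
//         }
//     }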

/// Operation intermediate representation specific to a float tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationIr {
    /// Operation corresponding to [exp](burn_backend::ops::FloatTensorOps::float_exp).
    Exp(UnaryOpIr),
    /// Operation corresponding to [log](burn_backend::ops::FloatTensorOps::float_log).
    Log(UnaryOpIr),
    /// Operation corresponding to [log1p](burn_backend::ops::FloatTensorOps::float_log1p).
    Log1p(UnaryOpIr),
    /// Operation corresponding to [erf](burn_backend::ops::FloatTensorOps::float_erf).
    Erf(UnaryOpIr),
    /// Operation corresponding to [powf_scalar](burn_backend::ops::FloatTensorOps::float_powf_scalar).
    PowfScalar(ScalarOpIr),
    /// Operation corresponding to [sqrt](burn_backend::ops::FloatTensorOps::float_sqrt).
    Sqrt(UnaryOpIr),
    /// Operation corresponding to [cos](burn_backend::ops::FloatTensorOps::float_cos).
    Cos(UnaryOpIr),
    /// Operation corresponding to [cosh](burn_backend::ops::FloatTensorOps::float_cosh).
    Cosh(UnaryOpIr),
    /// Operation corresponding to [sin](burn_backend::ops::FloatTensorOps::float_sin).
    Sin(UnaryOpIr),
    /// Operation corresponding to [sinh](burn_backend::ops::FloatTensorOps::float_sinh).
    Sinh(UnaryOpIr),
    /// Operation corresponding to [tan](burn_backend::ops::FloatTensorOps::float_tan).
    Tan(UnaryOpIr),
    /// Operation corresponding to [tanh](burn_backend::ops::FloatTensorOps::float_tanh).
    Tanh(UnaryOpIr),
    /// Operation corresponding to [acos](burn_backend::ops::FloatTensorOps::float_acos).
    ArcCos(UnaryOpIr),
    /// Operation corresponding to [acosh](burn_backend::ops::FloatTensorOps::float_acosh).
    ArcCosh(UnaryOpIr),
    /// Operation corresponding to [asin](burn_backend::ops::FloatTensorOps::float_asin).
    ArcSin(UnaryOpIr),
    /// Operation corresponding to [asinh](burn_backend::ops::FloatTensorOps::float_asinh).
    ArcSinh(UnaryOpIr),
    /// Operation corresponding to [atan](burn_backend::ops::FloatTensorOps::float_atan).
    ArcTan(UnaryOpIr),
    /// Operation corresponding to [atanh](burn_backend::ops::FloatTensorOps::float_atanh).
    ArcTanh(UnaryOpIr),
    /// Operation corresponding to [atan2](burn_backend::ops::FloatTensorOps::float_atan2).
    ArcTan2(BinaryOpIr),
    /// Operation corresponding to [round](burn_backend::ops::FloatTensorOps::float_round).
    Round(UnaryOpIr),
    /// Operation corresponding to [floor](burn_backend::ops::FloatTensorOps::float_floor).
    Floor(UnaryOpIr),
    /// Operation corresponding to [ceil](burn_backend::ops::FloatTensorOps::float_ceil).
    Ceil(UnaryOpIr),
    /// Operation corresponding to [trunc](burn_backend::ops::FloatTensorOps::float_trunc).
    Trunc(UnaryOpIr),
    /// Operation corresponding to [into_int](burn_backend::ops::FloatTensorOps::float_into_int).
    IntoInt(CastOpIr),
    /// Operation corresponding to [matmul](burn_backend::ops::FloatTensorOps::float_matmul).
    Matmul(MatmulOpIr),
    /// Operation corresponding to [cross](burn_backend::ops::FloatTensorOps::float_cross).
    Cross(CrossOpIr),
    /// Operation corresponding to [random](burn_backend::ops::FloatTensorOps::float_random).
    Random(RandomOpIr),
    /// Operation corresponding to [recip](burn_backend::ops::FloatTensorOps::float_recip).
    Recip(UnaryOpIr),
    /// Operation corresponding to [is_nan](burn_backend::ops::FloatTensorOps::float_is_nan).
    IsNan(UnaryOpIr),
    /// Operation corresponding to [is_inf](burn_backend::ops::FloatTensorOps::float_is_inf).
    IsInf(UnaryOpIr),
    /// Operation corresponding to [quantize](burn_backend::ops::QTensorOps::quantize).
    Quantize(QuantizeOpIr),
    /// Operation corresponding to [dequantize](burn_backend::ops::QTensorOps::dequantize).
    Dequantize(DequantizeOpIr),
    /// Operation corresponding to [grid_sample_2d](burn_backend::ops::FloatTensorOps::float_grid_sample_2d).
    GridSample2d(GridSample2dOpIr),
    /// Operation corresponding to [powf](burn_backend::ops::FloatTensorOps::float_powf).
    Powf(BinaryOpIr),
}

/// Operation intermediate representation specific to module.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationIr {
    /// Operation corresponding to [embedding](burn_backend::ops::ModuleOps::embedding).
    Embedding(EmbeddingOpIr),
    /// Operation corresponding to [embedding_backward](burn_backend::ops::ModuleOps::embedding_backward).
    EmbeddingBackward(EmbeddingBackwardOpIr),
    /// Operation corresponding to [linear](burn_backend::ops::ModuleOps::linear).
    Linear(LinearOpIr),
    /// Operation corresponding to [linear_x_backward](burn_backend::ops::ModuleOps::linear_x_backward).
    LinearXBackward(LinearXBackwardOpIr),
    /// Operation corresponding to [linear_weight_backward](burn_backend::ops::ModuleOps::linear_weight_backward).
    LinearWeightBackward(LinearWeightBackwardOpIr),
    /// Operation corresponding to [linear_bias_backward](burn_backend::ops::ModuleOps::linear_bias_backward).
    LinearBiasBackward(LinearBiasBackwardOpIr),
    /// Operation corresponding to [conv1d](burn_backend::ops::ModuleOps::conv1d).
    Conv1d(Conv1dOpIr),
    /// Operation corresponding to [conv1d_x_backward](burn_backend::ops::ModuleOps::conv1d_x_backward).
    Conv1dXBackward(Conv1dXBackwardOpIr),
    /// Operation corresponding to [conv1d_weight_backward](burn_backend::ops::ModuleOps::conv1d_weight_backward).
    Conv1dWeightBackward(Conv1dWeightBackwardOpIr),
    /// Operation corresponding to [conv1d_bias_backward](burn_backend::ops::ModuleOps::conv1d_bias_backward).
    Conv1dBiasBackward(Conv1dBiasBackwardOpIr),
    /// Operation corresponding to [conv2d](burn_backend::ops::ModuleOps::conv2d).
    Conv2d(Conv2dOpIr),
    /// Operation corresponding to [conv2d_x_backward](burn_backend::ops::ModuleOps::conv2d_x_backward).
    Conv2dXBackward(Conv2dXBackwardOpIr),
    /// Operation corresponding to [conv2d_weight_backward](burn_backend::ops::ModuleOps::conv2d_weight_backward).
    Conv2dWeightBackward(Conv2dWeightBackwardOpIr),
    /// Operation corresponding to [conv2d_bias_backward](burn_backend::ops::ModuleOps::conv2d_bias_backward).
    Conv2dBiasBackward(Conv2dBiasBackwardOpIr),
    /// Operation corresponding to [conv3d](burn_backend::ops::ModuleOps::conv3d).
    Conv3d(Conv3dOpIr),
    /// Operation corresponding to [conv3d_x_backward](burn_backend::ops::ModuleOps::conv3d_x_backward).
    Conv3dXBackward(Conv3dXBackwardOpIr),
    /// Operation corresponding to [conv3d_weight_backward](burn_backend::ops::ModuleOps::conv3d_weight_backward).
    Conv3dWeightBackward(Conv3dWeightBackwardOpIr),
    /// Operation corresponding to [conv3d_bias_backward](burn_backend::ops::ModuleOps::conv3d_bias_backward).
    Conv3dBiasBackward(Conv3dBiasBackwardOpIr),
    /// Operation corresponding to [deform_conv2d](burn_backend::ops::ModuleOps::deform_conv2d).
    DeformableConv2d(Box<DeformConv2dOpIr>),
    /// Operation corresponding to [deform_conv2d_backward](burn_backend::ops::ModuleOps::deform_conv2d_backward).
    DeformableConv2dBackward(Box<DeformConv2dBackwardOpIr>),
    /// Operation corresponding to [conv transpose 1d](burn_backend::ops::ModuleOps::conv_transpose1d).
    ConvTranspose1d(ConvTranspose1dOpIr),
    /// Operation corresponding to [conv transpose 2d](burn_backend::ops::ModuleOps::conv_transpose2d).
    ConvTranspose2d(ConvTranspose2dOpIr),
    /// Operation corresponding to [conv transpose 3d](burn_backend::ops::ModuleOps::conv_transpose3d).
    ConvTranspose3d(ConvTranspose3dOpIr),
    /// Operation corresponding to [avg pool 1d](burn_backend::ops::ModuleOps::avg_pool1d).
    AvgPool1d(AvgPool1dOpIr),
    /// Operation corresponding to [avg pool 2d](burn_backend::ops::ModuleOps::avg_pool2d).
    AvgPool2d(AvgPool2dOpIr),
    /// Operation corresponding to
    /// [avg pool 1d backward](burn_backend::ops::ModuleOps::avg_pool1d_backward).
    AvgPool1dBackward(AvgPool1dBackwardOpIr),
    /// Operation corresponding to
    /// [avg pool 2d backward](burn_backend::ops::ModuleOps::avg_pool2d_backward).
    AvgPool2dBackward(AvgPool2dBackwardOpIr),
    /// Operation corresponding to
    /// [adaptive avg pool 1d](burn_backend::ops::ModuleOps::adaptive_avg_pool1d).
    AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr),
    /// Operation corresponding to
    /// [adaptive avg pool 2d](burn_backend::ops::ModuleOps::adaptive_avg_pool2d).
    AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr),
    /// Operation corresponding to
    /// [adaptive avg pool 1d backward](burn_backend::ops::ModuleOps::adaptive_avg_pool1d_backward).
    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr),
    /// Operation corresponding to
    /// [adaptive avg pool 2d backward](burn_backend::ops::ModuleOps::adaptive_avg_pool2d_backward).
    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr),
    /// Operation corresponding to
    /// [max pool 1d](burn_backend::ops::ModuleOps::max_pool1d).
    MaxPool1d(MaxPool1dOpIr),
    /// Operation corresponding to
    /// [max pool 1d with indices](burn_backend::ops::ModuleOps::max_pool1d_with_indices).
    MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr),
    /// Operation corresponding to
    /// [max pool 1d with indices backward](burn_backend::ops::ModuleOps::max_pool1d_with_indices_backward).
    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr),
    /// Operation corresponding to
    /// [max pool 2d](burn_backend::ops::ModuleOps::max_pool2d).
    MaxPool2d(MaxPool2dOpIr),
    /// Operation corresponding to
    /// [max pool 2d with indices](burn_backend::ops::ModuleOps::max_pool2d_with_indices).
    MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr),
    /// Operation corresponding to
    /// [max pool 2d with indices backward](burn_backend::ops::ModuleOps::max_pool2d_with_indices_backward).
    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr),
    /// Operation corresponding to [interpolate](burn_backend::ops::ModuleOps::interpolate).
    Interpolate(InterpolateOpIr),
    /// Operation corresponding to [interpolate backward](burn_backend::ops::ModuleOps::interpolate_backward).
    InterpolateBackward(InterpolateBackwardOpIr),
    /// Operation corresponding to [rfft](burn_backend::ops::ModuleOps::rfft).
    Rfft(RfftOpIr),
    /// Operation corresponding to [irfft](burn_backend::ops::ModuleOps::irfft).
    IRfft(IRfftOpIr),
    /// Operation corresponding to [attention](burn_backend::ops::ModuleOps::attention).
    Attention(AttentionOpIr),
    /// Operation corresponding to [ctc_loss](burn_backend::ops::ModuleOps::ctc_loss).
    CtcLoss(CtcLossOpIr),
    /// Operation corresponding to
    /// [ctc_loss_backward](burn_backend::ops::ModuleOps::ctc_loss_backward).
    CtcLossBackward(CtcLossBackwardOpIr),
}

/// Basic operations that can be done on any tensor type.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BaseOperationIr {
    /// Operation corresponding to:
    ///
    /// Float => [reshape](burn_backend::ops::FloatTensorOps::float_reshape).
    /// Int => [reshape](burn_backend::ops::IntTensorOps::int_reshape).
    /// Bool => [reshape](burn_backend::ops::BoolTensorOps::bool_reshape).
    Reshape(ShapeOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [swap_dims](burn_backend::ops::FloatTensorOps::float_swap_dims).
    /// Int => [swap_dims](burn_backend::ops::IntTensorOps::int_swap_dims).
    /// Bool => [swap_dims](burn_backend::ops::BoolTensorOps::bool_swap_dims).
    SwapDims(SwapDimsOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [permute](burn_backend::ops::FloatTensorOps::float_permute).
    /// Int => [permute](burn_backend::ops::IntTensorOps::int_permute).
    /// Bool => [permute](burn_backend::ops::BoolTensorOps::bool_permute).
    Permute(PermuteOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [flip](burn_backend::ops::FloatTensorOps::float_flip).
    /// Int => [flip](burn_backend::ops::IntTensorOps::int_flip).
    /// Bool => [flip](burn_backend::ops::BoolTensorOps::bool_flip).
    Flip(FlipOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [expand](burn_backend::ops::FloatTensorOps::float_expand).
    /// Int => [expand](burn_backend::ops::IntTensorOps::int_expand).
    /// Bool => [expand](burn_backend::ops::BoolTensorOps::bool_expand).
    Expand(ShapeOpIr),

    /// Unfold windows along an axis.
    Unfold(UnfoldOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [slice](burn_backend::ops::FloatTensorOps::float_slice).
    /// Int => [slice](burn_backend::ops::IntTensorOps::int_slice).
    /// Bool => [slice](burn_backend::ops::BoolTensorOps::bool_slice).
    Slice(SliceOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [slice assign](burn_backend::ops::FloatTensorOps::float_slice_assign).
    /// Int => [slice assign](burn_backend::ops::IntTensorOps::int_slice_assign).
    /// Bool => [slice assign](burn_backend::ops::BoolTensorOps::bool_slice_assign).
    SliceAssign(SliceAssignOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [select](burn_backend::ops::FloatTensorOps::float_select).
    /// Int => [select](burn_backend::ops::IntTensorOps::int_select).
    /// Bool => [select](burn_backend::ops::BoolTensorOps::bool_select).
    Select(SelectOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [select assign](burn_backend::ops::FloatTensorOps::float_select_add).
    /// Int => [select assign](burn_backend::ops::IntTensorOps::int_select_add).
    /// Bool => [select assign](burn_backend::ops::BoolTensorOps::bool_select_or).
    SelectAssign(SelectAssignOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mask where](burn_backend::ops::FloatTensorOps::float_mask_where).
    /// Int => [mask where](burn_backend::ops::IntTensorOps::int_mask_where).
    /// Bool => [mask where](burn_backend::ops::BoolTensorOps::bool_mask_where).
    MaskWhere(MaskWhereOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mask fill](burn_backend::ops::FloatTensorOps::float_mask_fill).
    /// Int => [mask fill](burn_backend::ops::IntTensorOps::int_mask_fill).
    /// Bool => [mask fill](burn_backend::ops::BoolTensorOps::bool_mask_fill).
    MaskFill(MaskFillOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [gather](burn_backend::ops::FloatTensorOps::float_gather).
    /// Int => [gather](burn_backend::ops::IntTensorOps::int_gather).
    /// Bool => [gather](burn_backend::ops::BoolTensorOps::bool_gather).
    Gather(GatherOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [scatter](burn_backend::ops::FloatTensorOps::float_scatter_add).
    /// Int => [scatter](burn_backend::ops::IntTensorOps::int_scatter_add).
    /// Bool => [scatter](burn_backend::ops::BoolTensorOps::bool_scatter_or).
    Scatter(ScatterOpIr),
    /// Multi-dimensional scatter operation.
    ScatterNd(ScatterNdOpIr),
    /// Multi-dimensional gather operation.
    GatherNd(GatherNdOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [equal](burn_backend::ops::FloatTensorOps::float_equal).
    /// Int => [equal](burn_backend::ops::IntTensorOps::int_equal).
    /// Bool => [equal](burn_backend::ops::BoolTensorOps::bool_equal).
    Equal(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [equal elem](burn_backend::ops::FloatTensorOps::float_equal_elem).
    /// Int => [equal elem](burn_backend::ops::IntTensorOps::int_equal_elem).
    /// Bool => [equal elem](burn_backend::ops::BoolTensorOps::bool_equal_elem).
    EqualElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [repeat dim](burn_backend::ops::FloatTensorOps::float_repeat_dim).
    /// Int => [repeat dim](burn_backend::ops::IntTensorOps::int_repeat_dim).
    /// Bool => [repeat dim](burn_backend::ops::BoolTensorOps::bool_repeat_dim).
    RepeatDim(RepeatDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [cat](burn_backend::ops::FloatTensorOps::float_cat).
    /// Int => [cat](burn_backend::ops::IntTensorOps::int_cat).
    /// Bool => [cat](burn_backend::ops::BoolTensorOps::bool_cat).
    Cat(CatOpIr),
    /// Cast operation; it has no direct backend method and must be supported by the fusion backend.
    Cast(CastOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [empty](burn_backend::ops::FloatTensorOps::float_empty).
    /// Int => [empty](burn_backend::ops::IntTensorOps::int_empty).
    /// Bool => [empty](burn_backend::ops::BoolTensorOps::bool_empty).
    Empty(CreationOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [ones](burn_backend::ops::FloatTensorOps::float_ones).
    /// Int => [ones](burn_backend::ops::IntTensorOps::int_ones).
    /// Bool => [ones](burn_backend::ops::BoolTensorOps::bool_ones).
    Ones(CreationOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [zeros](burn_backend::ops::FloatTensorOps::float_zeros).
    /// Int => [zeros](burn_backend::ops::IntTensorOps::int_zeros).
    /// Bool => [zeros](burn_backend::ops::BoolTensorOps::bool_zeros).
    Zeros(CreationOpIr),
}

/// Numeric operations on int and float tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum NumericOperationIr {
    /// Operation corresponding to:
    ///
    /// Float => [add](burn_backend::ops::FloatTensorOps::float_add).
    /// Int => [add](burn_backend::ops::IntTensorOps::int_add).
    Add(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [add scalar](burn_backend::ops::FloatTensorOps::float_add_scalar).
    /// Int => [add scalar](burn_backend::ops::IntTensorOps::int_add_scalar).
    AddScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [sub](burn_backend::ops::FloatTensorOps::float_sub).
    /// Int => [sub](burn_backend::ops::IntTensorOps::int_sub).
    Sub(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [sub scalar](burn_backend::ops::FloatTensorOps::float_sub_scalar).
    /// Int => [sub scalar](burn_backend::ops::IntTensorOps::int_sub_scalar).
    SubScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [div](burn_backend::ops::FloatTensorOps::float_div).
    /// Int => [div](burn_backend::ops::IntTensorOps::int_div).
    Div(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [div scalar](burn_backend::ops::FloatTensorOps::float_div_scalar).
    /// Int => [div scalar](burn_backend::ops::IntTensorOps::int_div_scalar).
    DivScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [rem](burn_backend::ops::FloatTensorOps::float_remainder).
    /// Int => [rem](burn_backend::ops::IntTensorOps::int_remainder).
    Rem(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [rem scalar](burn_backend::ops::FloatTensorOps::float_remainder_scalar).
    /// Int => [rem scalar](burn_backend::ops::IntTensorOps::int_remainder_scalar).
    RemScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mul](burn_backend::ops::FloatTensorOps::float_mul).
    /// Int => [mul](burn_backend::ops::IntTensorOps::int_mul).
    Mul(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mul scalar](burn_backend::ops::FloatTensorOps::float_mul_scalar).
    /// Int => [mul scalar](burn_backend::ops::IntTensorOps::int_mul_scalar).
    MulScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [abs](burn_backend::ops::FloatTensorOps::float_abs).
    /// Int => [abs](burn_backend::ops::IntTensorOps::int_abs).
    Abs(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [full](burn_backend::ops::FloatTensorOps::float_full).
    /// Int => [full](burn_backend::ops::IntTensorOps::int_full).
    Full(FullOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mean dim](burn_backend::ops::FloatTensorOps::float_mean_dim).
    /// Int => [mean dim](burn_backend::ops::IntTensorOps::int_mean_dim).
    MeanDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mean](burn_backend::ops::FloatTensorOps::float_mean).
    /// Int => [mean](burn_backend::ops::IntTensorOps::int_mean).
    Mean(ReduceOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [sum](burn_backend::ops::FloatTensorOps::float_sum).
    /// Int => [sum](burn_backend::ops::IntTensorOps::int_sum).
    Sum(ReduceOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [sum dim](burn_backend::ops::FloatTensorOps::float_sum_dim).
    /// Int => [sum dim](burn_backend::ops::IntTensorOps::int_sum_dim).
    SumDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [prod](burn_backend::ops::FloatTensorOps::float_prod).
    /// Int => [prod](burn_backend::ops::IntTensorOps::int_prod).
    Prod(ReduceOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [prod dim](burn_backend::ops::FloatTensorOps::float_prod_dim).
    /// Int => [prod dim](burn_backend::ops::IntTensorOps::int_prod_dim).
    ProdDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [greater](burn_backend::ops::FloatTensorOps::float_greater).
    /// Int => [greater](burn_backend::ops::IntTensorOps::int_greater).
    Greater(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [greater elem](burn_backend::ops::FloatTensorOps::float_greater_elem).
    /// Int => [greater elem](burn_backend::ops::IntTensorOps::int_greater_elem).
    GreaterElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [greater equal](burn_backend::ops::FloatTensorOps::float_greater_equal).
    /// Int => [greater equal](burn_backend::ops::IntTensorOps::int_greater_equal).
    GreaterEqual(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [greater equal elem](burn_backend::ops::FloatTensorOps::float_greater_equal_elem).
    /// Int => [greater equal elem](burn_backend::ops::IntTensorOps::int_greater_equal_elem).
    GreaterEqualElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [lower](burn_backend::ops::FloatTensorOps::float_lower).
    /// Int => [lower](burn_backend::ops::IntTensorOps::int_lower).
    Lower(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [lower elem](burn_backend::ops::FloatTensorOps::float_lower_elem).
    /// Int => [lower elem](burn_backend::ops::IntTensorOps::int_lower_elem).
    LowerElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [lower equal](burn_backend::ops::FloatTensorOps::float_lower_equal).
    /// Int => [lower equal](burn_backend::ops::IntTensorOps::int_lower_equal).
    LowerEqual(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [lower equal elem](burn_backend::ops::FloatTensorOps::float_lower_equal_elem).
    /// Int => [lower equal elem](burn_backend::ops::IntTensorOps::int_lower_equal_elem).
    LowerEqualElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [argmax](burn_backend::ops::FloatTensorOps::float_argmax).
    /// Int => [argmax](burn_backend::ops::IntTensorOps::int_argmax).
    ArgMax(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [argtopk](burn_backend::ops::FloatTensorOps::float_argtopk).
    /// Int => [argtopk](burn_backend::ops::IntTensorOps::int_argtopk).
    ArgTopK(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [topk](burn_backend::ops::FloatTensorOps::float_topk).
    /// Int => [topk](burn_backend::ops::IntTensorOps::int_topk).
    TopK(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [argmin](burn_backend::ops::FloatTensorOps::float_argmin).
    /// Int => [argmin](burn_backend::ops::IntTensorOps::int_argmin).
    ArgMin(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max](burn_backend::ops::FloatTensorOps::float_max).
    /// Int => [max](burn_backend::ops::IntTensorOps::int_max).
    Max(ReduceOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max dim with indices](burn_backend::ops::FloatTensorOps::float_max_dim_with_indices).
    /// Int => [max dim with indices](burn_backend::ops::IntTensorOps::int_max_dim_with_indices).
    MaxDimWithIndices(ReduceDimWithIndicesOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [min dim with indices](burn_backend::ops::FloatTensorOps::float_min_dim_with_indices).
    /// Int => [min dim with indices](burn_backend::ops::IntTensorOps::int_min_dim_with_indices).
    MinDimWithIndices(ReduceDimWithIndicesOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [min](burn_backend::ops::FloatTensorOps::float_min).
    /// Int => [min](burn_backend::ops::IntTensorOps::int_min).
    Min(ReduceOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max dim](burn_backend::ops::FloatTensorOps::float_max_dim).
    /// Int => [max dim](burn_backend::ops::IntTensorOps::int_max_dim).
    MaxDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [min dim](burn_backend::ops::FloatTensorOps::float_min_dim).
    /// Int => [min dim](burn_backend::ops::IntTensorOps::int_min_dim).
    MinDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max_abs](burn_backend::ops::FloatTensorOps::float_max_abs).
    /// Int => [max_abs](burn_backend::ops::IntTensorOps::int_max_abs).
    MaxAbs(ReduceOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max_abs dim](burn_backend::ops::FloatTensorOps::float_max_abs_dim).
    /// Int => [max_abs dim](burn_backend::ops::IntTensorOps::int_max_abs_dim).
    MaxAbsDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [clamp](burn_backend::ops::FloatTensorOps::float_clamp).
    /// Int => [clamp](burn_backend::ops::IntTensorOps::int_clamp).
    Clamp(ClampOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [random](burn_backend::ops::IntTensorOps::int_random).
    IntRandom(RandomOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [powi](burn_backend::ops::FloatTensorOps::float_powi).
    /// Int => [powi](burn_backend::ops::IntTensorOps::int_powi).
    Powi(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [powi_scalar](burn_backend::ops::FloatTensorOps::float_powi_scalar).
    /// Int => [powi_scalar](burn_backend::ops::IntTensorOps::int_powi_scalar).
    PowiScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [cumsum](burn_backend::ops::FloatTensorOps::float_cumsum).
    /// Int => [cumsum](burn_backend::ops::IntTensorOps::int_cumsum).
    CumSum(DimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [cumprod](burn_backend::ops::FloatTensorOps::float_cumprod).
    /// Int => [cumprod](burn_backend::ops::IntTensorOps::int_cumprod).
    CumProd(DimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [cummin](burn_backend::ops::FloatTensorOps::float_cummin).
    /// Int => [cummin](burn_backend::ops::IntTensorOps::int_cummin).
    CumMin(DimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [cummax](burn_backend::ops::FloatTensorOps::float_cummax).
    /// Int => [cummax](burn_backend::ops::IntTensorOps::int_cummax).
    CumMax(DimOpIr),
}

/// Operation intermediate representation specific to an int tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationIr {
    /// Operation corresponding to [into float](burn_backend::ops::IntTensorOps::int_into_float).
    IntoFloat(CastOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise and](burn_backend::ops::IntTensorOps::bitwise_and).
    BitwiseAnd(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise and scalar](burn_backend::ops::IntTensorOps::bitwise_and_scalar).
    BitwiseAndScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise or](burn_backend::ops::IntTensorOps::bitwise_or).
    BitwiseOr(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise or scalar](burn_backend::ops::IntTensorOps::bitwise_or_scalar).
    BitwiseOrScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise xor](burn_backend::ops::IntTensorOps::bitwise_xor).
    BitwiseXor(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise xor scalar](burn_backend::ops::IntTensorOps::bitwise_xor_scalar).
    BitwiseXorScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise not](burn_backend::ops::IntTensorOps::bitwise_not).
    BitwiseNot(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise left shift](burn_backend::ops::IntTensorOps::bitwise_left_shift).
    BitwiseLeftShift(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise left shift scalar](burn_backend::ops::IntTensorOps::bitwise_left_shift_scalar).
    BitwiseLeftShiftScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise right shift](burn_backend::ops::IntTensorOps::bitwise_right_shift).
    BitwiseRightShift(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise right shift scalar](burn_backend::ops::IntTensorOps::bitwise_right_shift_scalar).
    BitwiseRightShiftScalar(ScalarOpIr),
    /// Operation corresponding to [matmul](burn_backend::ops::IntTensorOps::int_matmul).
    Matmul(MatmulOpIr),
}

/// Operation intermediate representation specific to a bool tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationIr {
    /// Operation corresponding to [into float](burn_backend::ops::BoolTensorOps::bool_into_float).
    IntoFloat(CastOpIr),
    /// Operation corresponding to [into int](burn_backend::ops::BoolTensorOps::bool_into_int).
    IntoInt(CastOpIr),
    /// Operation corresponding to [not](burn_backend::ops::BoolTensorOps::bool_not).
    Not(UnaryOpIr),
    /// Operation corresponding to [and](burn_backend::ops::BoolTensorOps::bool_and).
    And(BinaryOpIr),
    /// Operation corresponding to [or](burn_backend::ops::BoolTensorOps::bool_or).
    Or(BinaryOpIr),
}

#[cfg(feature = "distributed")]
/// Operations that can be done on distributed tensors.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum DistributedOperationIr {
    /// Operation corresponding to:
    /// [all_reduce](burn_backend::distributed::DistributedBackend::all_reduce).
    AllReduce(AllReduceOpIr),
}

/// Swap dim operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// The first dim to swap.
    pub dim1: usize,
    /// The second dim to swap.
    pub dim2: usize,
}

/// Permute operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// The new order of the dimensions.
    pub axes: Vec<usize>,
}

/// Shape operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ShapeOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation with the new shape.
    pub out: TensorIr,
}

/// Unfold operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct UnfoldOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,

    /// The selected dim.
    pub dim: usize,
    /// The window size.
    pub size: usize,
    /// The window step along dim.
    pub step: usize,
}

/// Flip operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// The dimensions to flip.
    pub axes: Vec<usize>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOpIr {
    pub out: TensorIr,
    pub distribution: Distribution,
}

/// Creation operation intermediate representation.
/// As opposed to [InitOperationIr], creation operations are lazily initialized.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CreationOpIr {
    /// Output tensor intermediate representation.
    pub out: TensorIr,
}

/// Full operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FullOpIr {
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// Fill value.
    pub value: ScalarIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
/// Declares a tensor has been initialized.
///
/// Registering it is necessary for proper orphan detection and to avoid memory leaks.
pub struct InitOperationIr {
    /// The initialized tensor.
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MatmulOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CrossOpIr {
    pub lhs: TensorIr,
    pub rhs: TensorIr,
    pub out: TensorIr,
    pub dim: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOpIr {
    pub lhs: TensorIr,
    // TODO: Make that an enum with `Value` and `Id` variants for relative/global
    // conversion.
    pub rhs: ScalarIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceDimOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axis: usize,
    pub accumulator_len: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CastOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}

/// IR for operations that operate along a dimension without reducing it.
/// Unlike `ReduceDimOpIr`, the output shape is the same as the input shape.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct DimOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
    pub axis: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub value: TensorIr,
    pub update: IndexingUpdateOp,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterNdOpIr {
    pub data: TensorIr,
    pub indices: TensorIr,
    pub values: TensorIr,
    pub reduction: IndexingUpdateOp,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherNdOpIr {
    pub data: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub indices: TensorIr,
    pub value: TensorIr,
    pub update: IndexingUpdateOp,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOpIr {
    pub tensor: TensorIr,
    pub ranges: Vec<Slice>,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceAssignOpIr {
    pub tensor: TensorIr,
    pub ranges: Vec<Slice>,
    pub value: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOpIr {
    pub tensor: TensorIr,
    pub mask: TensorIr,
    pub value: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOpIr {
    pub tensor: TensorIr,
    pub mask: TensorIr,
    pub value: ScalarIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOpIr {
    pub tensor: TensorIr,
    pub min: ScalarIr,
    pub max: ScalarIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatDimOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub times: usize,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOpIr {
    pub tensors: Vec<TensorIr>,
    pub dim: usize,
    pub out: TensorIr,
}

#[cfg(feature = "distributed")]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AllReduceOpIr {
    pub tensor: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesOpIr {
    pub tensor: TensorIr,
    pub dim: usize,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingOpIr {
    pub weights: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardOpIr {
    pub weights: TensorIr,
    pub out_grad: TensorIr,
    pub indices: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct LinearOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct LinearXBackwardOpIr {
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct LinearWeightBackwardOpIr {
    pub x: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct LinearBiasBackwardOpIr {
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv1dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dXBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv1dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dWeightBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv1dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dBiasBackwardOpIr {
    pub x: TensorIr,
    pub bias: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv2dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dXBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv2dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dWeightBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv2dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dBiasBackwardOpIr {
    pub x: TensorIr,
    pub bias: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dOpIr {
    pub x: TensorIr,
    pub offset: TensorIr,
    pub weight: TensorIr,
    pub mask: Option<TensorIr>,
    pub bias: Option<TensorIr>,
    pub options: DeformableConv2dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardOpIr {
    pub x: TensorIr,
    pub offset: TensorIr,
    pub weight: TensorIr,
    pub mask: Option<TensorIr>,
    pub bias: Option<TensorIr>,
    pub out_grad: TensorIr,
    pub options: DeformableConv2dOptionsIr,
    pub input_grad: TensorIr,
    pub offset_grad: TensorIr,
    pub weight_grad: TensorIr,
    pub mask_grad: Option<TensorIr>,
    pub bias_grad: Option<TensorIr>,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: Conv3dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dXBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv3dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dWeightBackwardOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub output_grad: TensorIr,
    pub options: Conv3dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dBiasBackwardOpIr {
    pub x: TensorIr,
    pub bias: TensorIr,
    pub output_grad: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose1dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose2dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOpIr {
    pub x: TensorIr,
    pub weight: TensorIr,
    pub bias: Option<TensorIr>,
    pub options: ConvTranspose3dOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsIr {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub weight_groups: usize,
    pub offset_groups: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsIr {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsIr {
    pub stride: [usize; 1],
    pub padding: [usize; 1],
    pub padding_out: [usize; 1],
    pub dilation: [usize; 1],
    pub groups: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsIr {
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub padding_out: [usize; 2],
    pub dilation: [usize; 2],
    pub groups: usize,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOptionsIr {
    pub stride: [usize; 3],
    pub padding: [usize; 3],
    pub padding_out: [usize; 3],
    pub dilation: [usize; 3],
    pub groups: usize,
}

/// Quantization parameters intermediate representation.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersIr {
    /// The scaling factor.
    pub scales: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOpIr {
    pub tensor: TensorIr,
    pub qparams: QuantizationParametersIr,
    pub scheme: QuantScheme,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOpIr {
    pub input: TensorIr,
    pub out: TensorIr,
}
1345
1346impl From<ConvOptions<1>> for Conv1dOptionsIr {
1347    fn from(value: ConvOptions<1>) -> Self {
1348        Self {
1349            stride: value.stride,
1350            padding: value.padding,
1351            dilation: value.dilation,
1352            groups: value.groups,
1353        }
1354    }
1355}
1356
1357impl From<ConvOptions<2>> for Conv2dOptionsIr {
1358    fn from(value: ConvOptions<2>) -> Self {
1359        Self {
1360            stride: value.stride,
1361            padding: value.padding,
1362            dilation: value.dilation,
1363            groups: value.groups,
1364        }
1365    }
1366}
1367
1368impl From<ConvOptions<3>> for Conv3dOptionsIr {
1369    fn from(value: ConvOptions<3>) -> Self {
1370        Self {
1371            stride: value.stride,
1372            padding: value.padding,
1373            dilation: value.dilation,
1374            groups: value.groups,
1375        }
1376    }
1377}
1378
1379impl From<DeformConvOptions<2>> for DeformableConv2dOptionsIr {
1380    fn from(value: DeformConvOptions<2>) -> Self {
1381        Self {
1382            stride: value.stride,
1383            padding: value.padding,
1384            dilation: value.dilation,
1385            weight_groups: value.weight_groups,
1386            offset_groups: value.offset_groups,
1387        }
1388    }
1389}
1390
1391impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsIr {
1392    fn from(value: ConvTransposeOptions<1>) -> Self {
1393        Self {
1394            stride: value.stride,
1395            padding: value.padding,
1396            padding_out: value.padding_out,
1397            dilation: value.dilation,
1398            groups: value.groups,
1399        }
1400    }
1401}
1402
1403impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsIr {
1404    fn from(value: ConvTransposeOptions<2>) -> Self {
1405        Self {
1406            stride: value.stride,
1407            padding: value.padding,
1408            padding_out: value.padding_out,
1409            dilation: value.dilation,
1410            groups: value.groups,
1411        }
1412    }
1413}
1414
1415impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsIr {
1416    fn from(value: ConvTransposeOptions<3>) -> Self {
1417        Self {
1418            stride: value.stride,
1419            padding: value.padding,
1420            padding_out: value.padding_out,
1421            dilation: value.dilation,
1422            groups: value.groups,
1423        }
1424    }
1425}
1426
1427impl From<Conv1dOptionsIr> for ConvOptions<1> {
1428    fn from(val: Conv1dOptionsIr) -> Self {
1429        ConvOptions {
1430            stride: val.stride,
1431            padding: val.padding,
1432            dilation: val.dilation,
1433            groups: val.groups,
1434        }
1435    }
1436}
1437
1438impl From<Conv2dOptionsIr> for ConvOptions<2> {
1439    fn from(val: Conv2dOptionsIr) -> Self {
1440        ConvOptions {
1441            stride: val.stride,
1442            padding: val.padding,
1443            dilation: val.dilation,
1444            groups: val.groups,
1445        }
1446    }
1447}
1448
1449impl From<Conv3dOptionsIr> for ConvOptions<3> {
1450    fn from(val: Conv3dOptionsIr) -> Self {
1451        ConvOptions {
1452            stride: val.stride,
1453            padding: val.padding,
1454            dilation: val.dilation,
1455            groups: val.groups,
1456        }
1457    }
1458}
1459
1460impl From<DeformableConv2dOptionsIr> for DeformConvOptions<2> {
1461    fn from(value: DeformableConv2dOptionsIr) -> Self {
1462        DeformConvOptions {
1463            stride: value.stride,
1464            padding: value.padding,
1465            dilation: value.dilation,
1466            weight_groups: value.weight_groups,
1467            offset_groups: value.offset_groups,
1468        }
1469    }
1470}
1471
1472impl From<ConvTranspose1dOptionsIr> for ConvTransposeOptions<1> {
1473    fn from(val: ConvTranspose1dOptionsIr) -> Self {
1474        ConvTransposeOptions {
1475            stride: val.stride,
1476            padding: val.padding,
1477            padding_out: val.padding_out,
1478            dilation: val.dilation,
1479            groups: val.groups,
1480        }
1481    }
1482}
1483
1484impl From<ConvTranspose2dOptionsIr> for ConvTransposeOptions<2> {
1485    fn from(val: ConvTranspose2dOptionsIr) -> Self {
1486        ConvTransposeOptions {
1487            stride: val.stride,
1488            padding: val.padding,
1489            padding_out: val.padding_out,
1490            dilation: val.dilation,
1491            groups: val.groups,
1492        }
1493    }
1494}
1495
1496impl From<ConvTranspose3dOptionsIr> for ConvTransposeOptions<3> {
1497    fn from(val: ConvTranspose3dOptionsIr) -> Self {
1498        ConvTransposeOptions {
1499            stride: val.stride,
1500            padding: val.padding,
1501            padding_out: val.padding_out,
1502            dilation: val.dilation,
1503            groups: val.groups,
1504        }
1505    }
1506}
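
// A minimal round-trip sketch through the IR form, assuming only the four
// public fields that the `From<Conv2dOptionsIr>` impl above constructs with
// a struct literal. Illustrative, not exhaustive.
#[test]
fn conv2d_options_roundtrip_sketch() {
    let original: ConvOptions<2> = ConvOptions {
        stride: [2, 2],
        padding: [1, 1],
        dilation: [1, 1],
        groups: 1,
    };
    let (stride, groups) = (original.stride, original.groups);

    // Backend options -> serializable IR -> backend options.
    let ir: Conv2dOptionsIr = original.into();
    let restored: ConvOptions<2> = ir.into();

    assert_eq!(restored.stride, stride);
    assert_eq!(restored.groups, groups);
}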

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub count_include_pad: bool,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dOpIr {
    pub x: TensorIr,
    pub output_size: usize,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dOpIr {
    pub x: TensorIr,
    pub output_size: [usize; 2],
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesOpIr {
    pub x: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub indices: TensorIr,
    pub kernel_size: usize,
    pub stride: usize,
    pub padding: usize,
    pub dilation: usize,
    pub ceil_mode: bool,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesOpIr {
    pub x: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
    pub out_indices: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub indices: TensorIr,
    pub kernel_size: [usize; 2],
    pub stride: [usize; 2],
    pub padding: [usize; 2],
    pub dilation: [usize; 2],
    pub ceil_mode: bool,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeIr {
    Nearest,
    Bilinear,
    Bicubic,
    Lanczos3,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsIr {
    pub mode: InterpolateModeIr,
    pub align_corners: bool,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOpIr {
    pub x: TensorIr,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RfftOpIr {
    pub signal: TensorIr,
    pub dim: usize,
    pub n: Option<usize>,
    pub out_re: TensorIr,
    pub out_im: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct IRfftOpIr {
    pub input_re: TensorIr,
    pub input_im: TensorIr,
    pub dim: usize,
    pub n: Option<usize>,
    pub out_signal: TensorIr,
}

#[allow(missing_docs)]
impl RfftOpIr {
    pub fn create<F>(signal: TensorIr, dim: usize, n: Option<usize>, mut new_id: F) -> Self
    where
        F: FnMut() -> crate::TensorId,
    {
        // `n` is required to be a power of two at the public API boundary, so
        // the output has `n / 2 + 1` bins (matching scipy/torch for pow2 n).
        let mut shape = signal.shape.clone();
        let fft_len = n.unwrap_or(shape[dim]);
        shape[dim] = fft_len / 2 + 1;
        let dtype = signal.dtype;

        Self {
            signal,
            dim,
            n,
            out_re: TensorIr::uninit(new_id(), shape.clone(), dtype),
            out_im: TensorIr::uninit(new_id(), shape, dtype),
        }
    }
}

#[allow(missing_docs)]
impl IRfftOpIr {
    pub fn create<F>(
        input_re: TensorIr,
        input_im: TensorIr,
        dim: usize,
        n: Option<usize>,
        mut new_id: F,
    ) -> Self
    where
        F: FnMut() -> crate::TensorId,
    {
        debug_assert!(
            input_re.shape[dim] >= 1,
            "IRfftOpIr: input spectrum dimension must be >= 1"
        );
        debug_assert!(
            !matches!(n, Some(0)),
            "IRfftOpIr: n must be >= 1 when specified"
        );
        let mut shape = input_re.shape.clone();
        shape[dim] = n.unwrap_or((shape[dim] - 1) * 2);
        let dtype = input_re.dtype;

        Self {
            input_re,
            input_im,
            dim,
            n,
            out_signal: TensorIr::uninit(new_id(), shape, dtype),
        }
    }
}
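
// A self-contained sketch of the shape arithmetic used by `RfftOpIr::create`
// and `IRfftOpIr::create`. The helper names here are hypothetical; only the
// formulas come from the constructors above.
#[test]
fn fft_shape_arithmetic_sketch() {
    fn rfft_bins(fft_len: usize) -> usize {
        fft_len / 2 + 1
    }
    fn irfft_len(bins: usize, n: Option<usize>) -> usize {
        n.unwrap_or((bins - 1) * 2)
    }

    // For an even (e.g. power-of-two) length, the default irfft length
    // recovers the original signal length: ((n / 2 + 1) - 1) * 2 == n.
    assert_eq!(irfft_len(rfft_bins(1024), None), 1024);
    // Odd lengths round down, so `n` must be passed explicitly to recover
    // the original length.
    assert_eq!(irfft_len(rfft_bins(1023), None), 1022);
    assert_eq!(irfft_len(rfft_bins(1023), Some(1023)), 1023);
}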

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AttentionOptionsIr {
    pub scale: Option<ScalarIr>,
    pub softcap: Option<ScalarIr>,
    pub is_causal: bool,
}

impl From<AttentionOptionsIr> for AttentionModuleOptions {
    fn from(ir: AttentionOptionsIr) -> Self {
        AttentionModuleOptions {
            scale: ir.scale.map(|s| s.elem()),
            softcap: ir.softcap.map(|s| s.elem()),
            is_causal: ir.is_causal,
        }
    }
}

impl From<AttentionModuleOptions> for AttentionOptionsIr {
    fn from(options: AttentionModuleOptions) -> Self {
        AttentionOptionsIr {
            scale: options.scale.map(ScalarIr::Float),
            softcap: options.softcap.map(ScalarIr::Float),
            is_causal: options.is_causal,
        }
    }
}
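
// A small conversion sketch, assuming `AttentionModuleOptions` can be built
// with a struct literal (as the `From` impl above does) and that its
// optional scalar fields accept a plain float literal.
#[test]
fn attention_options_into_ir_sketch() {
    let options = AttentionModuleOptions {
        scale: Some(0.125),
        softcap: None,
        is_causal: true,
    };

    let ir: AttentionOptionsIr = options.into();
    // Present scalars become `ScalarIr::Float`; absent ones stay `None`.
    assert!(matches!(&ir.scale, Some(ScalarIr::Float(_))));
    assert!(ir.softcap.is_none());
    assert!(ir.is_causal);
}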

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AttentionOpIr {
    pub query: TensorIr,
    pub key: TensorIr,
    pub value: TensorIr,
    pub mask: Option<TensorIr>,
    pub attn_bias: Option<TensorIr>,
    pub options: AttentionOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CtcLossOpIr {
    pub log_probs: TensorIr,
    pub targets: TensorIr,
    pub input_lengths: TensorIr,
    pub target_lengths: TensorIr,
    pub blank: usize,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CtcLossBackwardOpIr {
    pub log_probs: TensorIr,
    pub targets: TensorIr,
    pub input_lengths: TensorIr,
    pub target_lengths: TensorIr,
    pub grad_loss: TensorIr,
    pub blank: usize,
    pub out: TensorIr,
}

impl From<InterpolateModeIr> for InterpolateMode {
    fn from(val: InterpolateModeIr) -> Self {
        match val {
            InterpolateModeIr::Nearest => Self::Nearest,
            InterpolateModeIr::Bilinear => Self::Bilinear,
            InterpolateModeIr::Bicubic => Self::Bicubic,
            InterpolateModeIr::Lanczos3 => Self::Lanczos3,
        }
    }
}

impl From<InterpolateOptionsIr> for InterpolateOptions {
    fn from(val: InterpolateOptionsIr) -> Self {
        Self::new(val.mode.into()).with_align_corners(val.align_corners)
    }
}

impl From<InterpolateMode> for InterpolateModeIr {
    fn from(val: InterpolateMode) -> Self {
        match val {
            InterpolateMode::Nearest => Self::Nearest,
            InterpolateMode::Bilinear => Self::Bilinear,
            InterpolateMode::Bicubic => Self::Bicubic,
            InterpolateMode::Lanczos3 => Self::Lanczos3,
        }
    }
}

impl From<InterpolateOptions> for InterpolateOptionsIr {
    fn from(val: InterpolateOptions) -> Self {
        Self {
            mode: val.mode.into(),
            align_corners: val.align_corners,
        }
    }
}
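
// A short round-trip sketch using only the builder calls that appear in the
// `From` impls above (`InterpolateOptions::new` and `with_align_corners`).
#[test]
fn interpolate_options_roundtrip_sketch() {
    let options = InterpolateOptions::new(InterpolateMode::Bilinear).with_align_corners(true);

    let ir: InterpolateOptionsIr = options.into();
    assert_eq!(ir.mode, InterpolateModeIr::Bilinear);
    assert!(ir.align_corners);

    // And back to the backend type.
    let restored: InterpolateOptions = ir.into();
    assert!(restored.align_corners);
}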

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardOpIr {
    pub x: TensorIr,
    pub grad: TensorIr,
    pub output_size: [usize; 2],
    pub options: InterpolateOptionsIr,
    pub out: TensorIr,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum GridSamplePaddingModeIr {
    Zeros,
    Border,
    Reflection,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GridSampleOptionsIr {
    pub mode: InterpolateModeIr,
    pub padding_mode: GridSamplePaddingModeIr,
    pub align_corners: bool,
}

#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GridSample2dOpIr {
    pub tensor: TensorIr,
    pub grid: TensorIr,
    pub options: GridSampleOptionsIr,
    pub out: TensorIr,
}

impl From<GridSamplePaddingModeIr> for GridSamplePaddingMode {
    fn from(val: GridSamplePaddingModeIr) -> Self {
        match val {
            GridSamplePaddingModeIr::Zeros => Self::Zeros,
            GridSamplePaddingModeIr::Border => Self::Border,
            GridSamplePaddingModeIr::Reflection => Self::Reflection,
        }
    }
}

impl From<GridSamplePaddingMode> for GridSamplePaddingModeIr {
    fn from(val: GridSamplePaddingMode) -> Self {
        match val {
            GridSamplePaddingMode::Zeros => Self::Zeros,
            GridSamplePaddingMode::Border => Self::Border,
            GridSamplePaddingMode::Reflection => Self::Reflection,
        }
    }
}

impl From<GridSampleOptionsIr> for GridSampleOptions {
    fn from(val: GridSampleOptionsIr) -> Self {
        Self {
            mode: val.mode.into(),
            padding_mode: val.padding_mode.into(),
            align_corners: val.align_corners,
        }
    }
}

impl From<GridSampleOptions> for GridSampleOptionsIr {
    fn from(val: GridSampleOptions) -> Self {
        Self {
            mode: val.mode.into(),
            padding_mode: val.padding_mode.into(),
            align_corners: val.align_corners,
        }
    }
}

impl OperationIr {
    /// Get all input [tensors](TensorIr) involved with the current operation.
    pub fn inputs(&self) -> impl Iterator<Item = &TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.inputs(),
            OperationIr::BaseInt(repr) => repr.inputs(),
            OperationIr::BaseBool(repr) => repr.inputs(),
            OperationIr::NumericFloat(_dtype, repr) => repr.inputs(),
            OperationIr::NumericInt(_dtype, repr) => repr.inputs(),
            OperationIr::Bool(repr) => repr.inputs(),
            OperationIr::Int(repr) => repr.inputs(),
            OperationIr::Float(_dtype, repr) => repr.inputs(),
            OperationIr::Module(repr) => repr.inputs(),
            OperationIr::Init(repr) => repr.inputs(),
            OperationIr::Custom(repr) => repr.inputs(),
            OperationIr::Drop(repr) => Box::new([repr].into_iter()),
            #[cfg(feature = "distributed")]
            OperationIr::Distributed(repr) => repr.inputs(),
        }
    }

    /// Get all output [tensors](TensorIr) involved with the current operation.
    pub fn outputs(&self) -> impl Iterator<Item = &TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.outputs(),
            OperationIr::BaseInt(repr) => repr.outputs(),
            OperationIr::BaseBool(repr) => repr.outputs(),
            OperationIr::NumericFloat(_dtype, repr) => repr.outputs(),
            OperationIr::NumericInt(_dtype, repr) => repr.outputs(),
            OperationIr::Bool(repr) => repr.outputs(),
            OperationIr::Int(repr) => repr.outputs(),
            OperationIr::Float(_dtype, repr) => repr.outputs(),
            OperationIr::Module(repr) => repr.outputs(),
            OperationIr::Init(repr) => repr.outputs(),
            OperationIr::Custom(repr) => repr.outputs(),
            OperationIr::Drop(_repr) => Box::new([].into_iter()),
            #[cfg(feature = "distributed")]
            OperationIr::Distributed(repr) => repr.outputs(),
        }
    }

    /// Get all [tensors](TensorIr) involved with the current operation.
    pub fn nodes(&self) -> Vec<&TensorIr> {
        self.inputs().chain(self.outputs()).collect()
    }

    /// Set the given nodes that are [read write](super::TensorStatus::ReadWrite) to
    /// [read only](super::TensorStatus::ReadOnly) in the current operation.
    ///
    /// Returns the tensors that were updated, with their original representation.
    pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseInt(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseBool(repr) => repr.mark_read_only(nodes),
            OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Bool(repr) => repr.mark_read_only(nodes),
            OperationIr::Int(repr) => repr.mark_read_only(nodes),
            OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Module(repr) => repr.mark_read_only(nodes),
            OperationIr::Init(_) => Vec::new(),
            OperationIr::Drop(repr) => {
                let mut output = Vec::new();
                repr.mark_read_only(nodes, &mut output);
                output
            }
            OperationIr::Custom(repr) => {
                let mut output = Vec::new();

                for input in repr.inputs.iter_mut() {
                    input.mark_read_only(nodes, &mut output);
                }

                output
            }
            #[cfg(feature = "distributed")]
            OperationIr::Distributed(repr) => repr.mark_read_only(nodes),
        }
    }
}
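
// Hedged usage sketch: before fusing a stream, tensors that later operations
// still read are downgraded from read-write to read-only via
// `mark_read_only`; the originals are returned so the caller can roll the
// statuses back if the fusion attempt is abandoned. `op` and `shared` are
// assumed to come from the surrounding fusion machinery.
#[allow(dead_code)]
fn downgrade_shared_tensors(op: &mut OperationIr, shared: &[TensorId]) -> Vec<TensorIr> {
    op.mark_read_only(shared)
}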
2026
2027impl BaseOperationIr {
2028    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2029        match self {
2030            BaseOperationIr::Reshape(repr) => Box::new([&repr.input].into_iter()),
2031            BaseOperationIr::SwapDims(repr) => Box::new([&repr.input].into_iter()),
2032            BaseOperationIr::Permute(repr) => Box::new([&repr.input].into_iter()),
2033            BaseOperationIr::Expand(repr) => Box::new([&repr.input].into_iter()),
2034            BaseOperationIr::Flip(repr) => Box::new([&repr.input].into_iter()),
2035            BaseOperationIr::Slice(repr) => Box::new([&repr.tensor].into_iter()),
2036            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.tensor, &repr.value].into_iter()),
2037            BaseOperationIr::Gather(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
2038            BaseOperationIr::Scatter(repr) => {
2039                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
2040            }
2041            BaseOperationIr::ScatterNd(repr) => {
2042                Box::new([&repr.data, &repr.indices, &repr.values].into_iter())
2043            }
2044            BaseOperationIr::GatherNd(repr) => Box::new([&repr.data, &repr.indices].into_iter()),
2045            BaseOperationIr::Select(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
2046            BaseOperationIr::SelectAssign(repr) => {
2047                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
2048            }
2049            BaseOperationIr::MaskWhere(repr) => {
2050                Box::new([&repr.tensor, &repr.mask, &repr.value].into_iter())
2051            }
2052            BaseOperationIr::MaskFill(repr) => Box::new([&repr.tensor, &repr.mask].into_iter()),
2053            BaseOperationIr::Equal(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2054            BaseOperationIr::EqualElem(repr) => Box::new([&repr.lhs].into_iter()),
2055            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.tensor].into_iter()),
2056            BaseOperationIr::Cat(repr) => Box::new(repr.tensors.iter()),
2057            BaseOperationIr::Cast(repr) => Box::new([&repr.input].into_iter()),
2058            BaseOperationIr::Unfold(repr) => Box::new([&repr.input].into_iter()),
2059            BaseOperationIr::Empty(_repr) => Box::new([].into_iter()),
2060            BaseOperationIr::Ones(_repr) => Box::new([].into_iter()),
2061            BaseOperationIr::Zeros(_repr) => Box::new([].into_iter()),
2062        }
2063    }
2064
2065    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2066        match self {
2067            BaseOperationIr::Reshape(repr) => Box::new([&repr.out].into_iter()),
2068            BaseOperationIr::SwapDims(repr) => Box::new([&repr.out].into_iter()),
2069            BaseOperationIr::Permute(repr) => Box::new([&repr.out].into_iter()),
2070            BaseOperationIr::Expand(repr) => Box::new([&repr.out].into_iter()),
2071            BaseOperationIr::Flip(repr) => Box::new([&repr.out].into_iter()),
2072            BaseOperationIr::Slice(repr) => Box::new([&repr.out].into_iter()),
2073            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.out].into_iter()),
2074            BaseOperationIr::Gather(repr) => Box::new([&repr.out].into_iter()),
2075            BaseOperationIr::Scatter(repr) => Box::new([&repr.out].into_iter()),
2076            BaseOperationIr::ScatterNd(repr) => Box::new([&repr.out].into_iter()),
2077            BaseOperationIr::GatherNd(repr) => Box::new([&repr.out].into_iter()),
2078            BaseOperationIr::Select(repr) => Box::new([&repr.out].into_iter()),
2079            BaseOperationIr::SelectAssign(repr) => Box::new([&repr.out].into_iter()),
2080            BaseOperationIr::MaskWhere(repr) => Box::new([&repr.out].into_iter()),
2081            BaseOperationIr::MaskFill(repr) => Box::new([&repr.out].into_iter()),
2082            BaseOperationIr::Equal(repr) => Box::new([&repr.out].into_iter()),
2083            BaseOperationIr::EqualElem(repr) => Box::new([&repr.out].into_iter()),
2084            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.out].into_iter()),
2085            BaseOperationIr::Cat(repr) => Box::new([&repr.out].into_iter()),
2086            BaseOperationIr::Cast(repr) => Box::new([&repr.out].into_iter()),
2087            BaseOperationIr::Unfold(repr) => Box::new([&repr.out].into_iter()),
2088            BaseOperationIr::Empty(repr) => Box::new([&repr.out].into_iter()),
2089            BaseOperationIr::Ones(repr) => Box::new([&repr.out].into_iter()),
2090            BaseOperationIr::Zeros(repr) => Box::new([&repr.out].into_iter()),
2091        }
2092    }
2093
2094    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2095        let mut output = Vec::new();
2096
2097        match self {
2098            BaseOperationIr::Reshape(repr) => {
2099                repr.input.mark_read_only(nodes, &mut output);
2100            }
2101            BaseOperationIr::SwapDims(repr) => {
2102                repr.input.mark_read_only(nodes, &mut output);
2103            }
2104            BaseOperationIr::Permute(repr) => {
2105                repr.input.mark_read_only(nodes, &mut output);
2106            }
2107
2108            BaseOperationIr::Expand(repr) => {
2109                repr.input.mark_read_only(nodes, &mut output);
2110            }
2111
2112            BaseOperationIr::Flip(repr) => {
2113                repr.input.mark_read_only(nodes, &mut output);
2114            }
2115            BaseOperationIr::Slice(repr) => {
2116                repr.tensor.mark_read_only(nodes, &mut output);
2117            }
2118            BaseOperationIr::SliceAssign(repr) => {
2119                repr.tensor.mark_read_only(nodes, &mut output);
2120                repr.value.mark_read_only(nodes, &mut output);
2121            }
2122            BaseOperationIr::Gather(repr) => {
2123                repr.tensor.mark_read_only(nodes, &mut output);
2124                repr.indices.mark_read_only(nodes, &mut output);
2125            }
2126            BaseOperationIr::Scatter(repr) => {
2127                repr.tensor.mark_read_only(nodes, &mut output);
2128                repr.indices.mark_read_only(nodes, &mut output);
2129                repr.value.mark_read_only(nodes, &mut output);
2130            }
2131            BaseOperationIr::ScatterNd(repr) => {
2132                repr.data.mark_read_only(nodes, &mut output);
2133                repr.indices.mark_read_only(nodes, &mut output);
2134                repr.values.mark_read_only(nodes, &mut output);
2135            }
2136            BaseOperationIr::GatherNd(repr) => {
2137                repr.data.mark_read_only(nodes, &mut output);
2138                repr.indices.mark_read_only(nodes, &mut output);
2139            }
2140            BaseOperationIr::Select(repr) => {
2141                repr.tensor.mark_read_only(nodes, &mut output);
2142                repr.indices.mark_read_only(nodes, &mut output);
2143            }
2144            BaseOperationIr::SelectAssign(repr) => {
2145                repr.tensor.mark_read_only(nodes, &mut output);
2146                repr.indices.mark_read_only(nodes, &mut output);
2147                repr.value.mark_read_only(nodes, &mut output);
2148            }
2149            BaseOperationIr::MaskWhere(repr) => {
2150                repr.tensor.mark_read_only(nodes, &mut output);
2151                repr.mask.mark_read_only(nodes, &mut output);
2152                repr.value.mark_read_only(nodes, &mut output);
2153            }
2154            BaseOperationIr::MaskFill(repr) => {
2155                repr.tensor.mark_read_only(nodes, &mut output);
2156                repr.mask.mark_read_only(nodes, &mut output);
2157            }
2158            BaseOperationIr::Equal(repr) => {
2159                repr.lhs.mark_read_only(nodes, &mut output);
2160                repr.rhs.mark_read_only(nodes, &mut output);
2161            }
2162            BaseOperationIr::EqualElem(repr) => {
2163                repr.lhs.mark_read_only(nodes, &mut output);
2164            }
2165            BaseOperationIr::RepeatDim(repr) => {
2166                repr.tensor.mark_read_only(nodes, &mut output);
2167            }
2168            BaseOperationIr::Cat(repr) => {
2169                for t in repr.tensors.iter_mut() {
2170                    t.mark_read_only(nodes, &mut output);
2171                }
2172            }
2173            BaseOperationIr::Cast(repr) => {
2174                repr.input.mark_read_only(nodes, &mut output);
2175            }
2176            BaseOperationIr::Unfold(repr) => {
2177                repr.input.mark_read_only(nodes, &mut output);
2178            }
2179            BaseOperationIr::Empty(_) => {}
2180            BaseOperationIr::Zeros(_) => {}
2181            BaseOperationIr::Ones(_) => {}
2182        };
2183
2184        output
2185    }
2186}
2187
2188impl NumericOperationIr {
2189    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2190        match self {
2191            NumericOperationIr::Add(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2192            NumericOperationIr::AddScalar(repr) => Box::new([&repr.lhs].into_iter()),
2193            NumericOperationIr::Sub(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2194            NumericOperationIr::SubScalar(repr) => Box::new([&repr.lhs].into_iter()),
2195            NumericOperationIr::Mul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2196            NumericOperationIr::MulScalar(repr) => Box::new([&repr.lhs].into_iter()),
2197            NumericOperationIr::Div(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2198            NumericOperationIr::DivScalar(repr) => Box::new([&repr.lhs].into_iter()),
2199            NumericOperationIr::Rem(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2200            NumericOperationIr::RemScalar(repr) => Box::new([&repr.lhs].into_iter()),
2201            NumericOperationIr::GreaterElem(repr) => Box::new([&repr.lhs].into_iter()),
2202            NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
2203            NumericOperationIr::LowerElem(repr) => Box::new([&repr.lhs].into_iter()),
2204            NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
2205            NumericOperationIr::Greater(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2206            NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2207            NumericOperationIr::Lower(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2208            NumericOperationIr::LowerEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2209            NumericOperationIr::ArgMax(repr) => Box::new([&repr.input].into_iter()),
2210            NumericOperationIr::ArgTopK(repr) => Box::new([&repr.input].into_iter()),
2211            NumericOperationIr::TopK(repr) => Box::new([&repr.input].into_iter()),
2212            NumericOperationIr::ArgMin(repr) => Box::new([&repr.input].into_iter()),
2213            NumericOperationIr::Clamp(repr) => Box::new([&repr.tensor].into_iter()),
2214            NumericOperationIr::Abs(repr) => Box::new([&repr.input].into_iter()),
2215            NumericOperationIr::Full(_repr) => Box::new([].into_iter()),
2216            NumericOperationIr::MeanDim(repr) => Box::new([&repr.input].into_iter()),
2217            NumericOperationIr::Mean(repr) => Box::new([&repr.input].into_iter()),
2218            NumericOperationIr::Sum(repr) => Box::new([&repr.input].into_iter()),
2219            NumericOperationIr::SumDim(repr) => Box::new([&repr.input].into_iter()),
2220            NumericOperationIr::Prod(repr) => Box::new([&repr.input].into_iter()),
2221            NumericOperationIr::ProdDim(repr) => Box::new([&repr.input].into_iter()),
2222            NumericOperationIr::Max(repr) => Box::new([&repr.input].into_iter()),
2223            NumericOperationIr::MaxDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
2224            NumericOperationIr::MinDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
2225            NumericOperationIr::Min(repr) => Box::new([&repr.input].into_iter()),
2226            NumericOperationIr::MaxDim(repr) => Box::new([&repr.input].into_iter()),
2227            NumericOperationIr::MinDim(repr) => Box::new([&repr.input].into_iter()),
2228            NumericOperationIr::MaxAbs(repr) => Box::new([&repr.input].into_iter()),
2229            NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.input].into_iter()),
2230            NumericOperationIr::IntRandom(_repr) => Box::new([].into_iter()),
2231            NumericOperationIr::Powi(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2232            NumericOperationIr::PowiScalar(repr) => Box::new([&repr.lhs].into_iter()),
2233            NumericOperationIr::CumMin(repr) => Box::new([&repr.input].into_iter()),
2234            NumericOperationIr::CumMax(repr) => Box::new([&repr.input].into_iter()),
2235            NumericOperationIr::CumProd(repr) => Box::new([&repr.input].into_iter()),
2236            NumericOperationIr::CumSum(repr) => Box::new([&repr.input].into_iter()),
2237        }
2238    }
2239
2240    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2241        match self {
2242            NumericOperationIr::Add(repr) => Box::new([&repr.out].into_iter()),
2243            NumericOperationIr::AddScalar(repr) => Box::new([&repr.out].into_iter()),
2244            NumericOperationIr::Sub(repr) => Box::new([&repr.out].into_iter()),
2245            NumericOperationIr::SubScalar(repr) => Box::new([&repr.out].into_iter()),
2246            NumericOperationIr::Mul(repr) => Box::new([&repr.out].into_iter()),
2247            NumericOperationIr::MulScalar(repr) => Box::new([&repr.out].into_iter()),
2248            NumericOperationIr::Div(repr) => Box::new([&repr.out].into_iter()),
2249            NumericOperationIr::DivScalar(repr) => Box::new([&repr.out].into_iter()),
2250            NumericOperationIr::Rem(repr) => Box::new([&repr.out].into_iter()),
2251            NumericOperationIr::RemScalar(repr) => Box::new([&repr.out].into_iter()),
2252            NumericOperationIr::GreaterElem(repr) => Box::new([&repr.out].into_iter()),
2253            NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.out].into_iter()),
2254            NumericOperationIr::LowerElem(repr) => Box::new([&repr.out].into_iter()),
2255            NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.out].into_iter()),
2256            NumericOperationIr::Greater(repr) => Box::new([&repr.out].into_iter()),
2257            NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.out].into_iter()),
2258            NumericOperationIr::Lower(repr) => Box::new([&repr.out].into_iter()),
2259            NumericOperationIr::LowerEqual(repr) => Box::new([&repr.out].into_iter()),
2260            NumericOperationIr::ArgMax(repr) => Box::new([&repr.out].into_iter()),
2261            NumericOperationIr::ArgTopK(repr) => Box::new([&repr.out].into_iter()),
2262            NumericOperationIr::TopK(repr) => Box::new([&repr.out].into_iter()),
2263            NumericOperationIr::ArgMin(repr) => Box::new([&repr.out].into_iter()),
2264            NumericOperationIr::Clamp(repr) => Box::new([&repr.out].into_iter()),
2265            NumericOperationIr::Abs(repr) => Box::new([&repr.out].into_iter()),
2266            NumericOperationIr::Full(repr) => Box::new([&repr.out].into_iter()),
2267            NumericOperationIr::MeanDim(repr) => Box::new([&repr.out].into_iter()),
2268            NumericOperationIr::Mean(repr) => Box::new([&repr.out].into_iter()),
2269            NumericOperationIr::Sum(repr) => Box::new([&repr.out].into_iter()),
2270            NumericOperationIr::SumDim(repr) => Box::new([&repr.out].into_iter()),
2271            NumericOperationIr::Prod(repr) => Box::new([&repr.out].into_iter()),
2272            NumericOperationIr::ProdDim(repr) => Box::new([&repr.out].into_iter()),
2273            NumericOperationIr::Max(repr) => Box::new([&repr.out].into_iter()),
2274            NumericOperationIr::MaxDimWithIndices(repr) => {
2275                Box::new([&repr.out, &repr.out_indices].into_iter())
2276            }
2277            NumericOperationIr::MinDimWithIndices(repr) => {
2278                Box::new([&repr.out, &repr.out_indices].into_iter())
2279            }
2280            NumericOperationIr::Min(repr) => Box::new([&repr.out].into_iter()),
2281            NumericOperationIr::MaxDim(repr) => Box::new([&repr.out].into_iter()),
2282            NumericOperationIr::MinDim(repr) => Box::new([&repr.out].into_iter()),
2283            NumericOperationIr::MaxAbs(repr) => Box::new([&repr.out].into_iter()),
2284            NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.out].into_iter()),
2285            NumericOperationIr::IntRandom(repr) => Box::new([&repr.out].into_iter()),
2286            NumericOperationIr::Powi(repr) => Box::new([&repr.out].into_iter()),
2287            NumericOperationIr::PowiScalar(repr) => Box::new([&repr.out].into_iter()),
2288            NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),
2289            NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),
2290            NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),
2291            NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),
2292        }
2293    }
2294    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2295        let mut output = Vec::new();
2296
2297        match self {
2298            NumericOperationIr::Add(repr) => {
2299                repr.lhs.mark_read_only(nodes, &mut output);
2300                repr.rhs.mark_read_only(nodes, &mut output);
2301            }
2302            NumericOperationIr::AddScalar(repr) => {
2303                repr.lhs.mark_read_only(nodes, &mut output);
2304            }
2305            NumericOperationIr::Sub(repr) => {
2306                repr.lhs.mark_read_only(nodes, &mut output);
2307                repr.rhs.mark_read_only(nodes, &mut output);
2308            }
2309            NumericOperationIr::SubScalar(repr) => {
2310                repr.lhs.mark_read_only(nodes, &mut output);
2311            }
2312            NumericOperationIr::Mul(repr) => {
2313                repr.lhs.mark_read_only(nodes, &mut output);
2314                repr.rhs.mark_read_only(nodes, &mut output);
2315            }
2316            NumericOperationIr::MulScalar(repr) => {
2317                repr.lhs.mark_read_only(nodes, &mut output);
2318            }
2319            NumericOperationIr::Div(repr) => {
2320                repr.lhs.mark_read_only(nodes, &mut output);
2321                repr.rhs.mark_read_only(nodes, &mut output);
2322            }
2323            NumericOperationIr::DivScalar(repr) => {
2324                repr.lhs.mark_read_only(nodes, &mut output);
2325            }
2326            NumericOperationIr::Rem(repr) => {
2327                repr.lhs.mark_read_only(nodes, &mut output);
2328                repr.rhs.mark_read_only(nodes, &mut output);
2329            }
2330            NumericOperationIr::RemScalar(repr) => {
2331                repr.lhs.mark_read_only(nodes, &mut output);
2332            }
2333            NumericOperationIr::GreaterElem(repr) => {
2334                repr.lhs.mark_read_only(nodes, &mut output);
2335            }
2336            NumericOperationIr::GreaterEqualElem(repr) => {
2337                repr.lhs.mark_read_only(nodes, &mut output);
2338            }
2339            NumericOperationIr::LowerElem(repr) => {
2340                repr.lhs.mark_read_only(nodes, &mut output);
2341            }
2342            NumericOperationIr::LowerEqualElem(repr) => {
2343                repr.lhs.mark_read_only(nodes, &mut output);
2344            }
2345            NumericOperationIr::Greater(repr) => {
2346                repr.lhs.mark_read_only(nodes, &mut output);
2347                repr.rhs.mark_read_only(nodes, &mut output);
2348            }
2349            NumericOperationIr::GreaterEqual(repr) => {
2350                repr.lhs.mark_read_only(nodes, &mut output);
2351                repr.rhs.mark_read_only(nodes, &mut output);
2352            }
2353            NumericOperationIr::Lower(repr) => {
2354                repr.lhs.mark_read_only(nodes, &mut output);
2355                repr.rhs.mark_read_only(nodes, &mut output);
2356            }
2357            NumericOperationIr::LowerEqual(repr) => {
2358                repr.lhs.mark_read_only(nodes, &mut output);
2359                repr.rhs.mark_read_only(nodes, &mut output);
2360            }
2361            NumericOperationIr::ArgMax(repr) => {
2362                repr.input.mark_read_only(nodes, &mut output);
2363            }
2364            NumericOperationIr::ArgTopK(repr) => {
2365                repr.input.mark_read_only(nodes, &mut output);
2366            }
2367            NumericOperationIr::TopK(repr) => {
2368                repr.input.mark_read_only(nodes, &mut output);
2369            }
2370            NumericOperationIr::ArgMin(repr) => {
2371                repr.input.mark_read_only(nodes, &mut output);
2372            }
2373            NumericOperationIr::Clamp(repr) => {
2374                repr.tensor.mark_read_only(nodes, &mut output);
2375            }
2376            NumericOperationIr::Abs(repr) => {
2377                repr.input.mark_read_only(nodes, &mut output);
2378            }
2379            NumericOperationIr::Full(_) => {}
2380            NumericOperationIr::MeanDim(repr) => {
2381                repr.input.mark_read_only(nodes, &mut output);
2382            }
2383            NumericOperationIr::Mean(repr) => {
2384                repr.input.mark_read_only(nodes, &mut output);
2385            }
2386            NumericOperationIr::Sum(repr) => {
2387                repr.input.mark_read_only(nodes, &mut output);
2388            }
2389            NumericOperationIr::SumDim(repr) => {
2390                repr.input.mark_read_only(nodes, &mut output);
2391            }
2392            NumericOperationIr::Prod(repr) => {
2393                repr.input.mark_read_only(nodes, &mut output);
2394            }
2395            NumericOperationIr::ProdDim(repr) => {
2396                repr.input.mark_read_only(nodes, &mut output);
2397            }
2398            NumericOperationIr::Max(repr) => {
2399                repr.input.mark_read_only(nodes, &mut output);
2400            }
2401            NumericOperationIr::MaxDimWithIndices(repr) => {
2402                repr.tensor.mark_read_only(nodes, &mut output);
2403            }
2404            NumericOperationIr::MinDimWithIndices(repr) => {
2405                repr.tensor.mark_read_only(nodes, &mut output);
2406            }
2407            NumericOperationIr::Min(repr) => {
2408                repr.input.mark_read_only(nodes, &mut output);
2409            }
2410            NumericOperationIr::MaxDim(repr) => {
2411                repr.input.mark_read_only(nodes, &mut output);
2412            }
2413            NumericOperationIr::MinDim(repr) => {
2414                repr.input.mark_read_only(nodes, &mut output);
2415            }
2416            NumericOperationIr::MaxAbs(repr) => {
2417                repr.input.mark_read_only(nodes, &mut output);
2418            }
2419            NumericOperationIr::MaxAbsDim(repr) => {
2420                repr.input.mark_read_only(nodes, &mut output);
2421            }
2422            NumericOperationIr::IntRandom(_) => {}
2423            NumericOperationIr::Powi(repr) => {
2424                repr.lhs.mark_read_only(nodes, &mut output);
2425                repr.rhs.mark_read_only(nodes, &mut output);
2426            }
2427            NumericOperationIr::PowiScalar(repr) => {
2428                repr.lhs.mark_read_only(nodes, &mut output);
2429            }
2430            NumericOperationIr::CumSum(repr) => {
2431                repr.input.mark_read_only(nodes, &mut output);
2432            }
2433            NumericOperationIr::CumProd(repr) => {
2434                repr.input.mark_read_only(nodes, &mut output);
2435            }
2436            NumericOperationIr::CumMin(repr) => {
2437                repr.input.mark_read_only(nodes, &mut output);
2438            }
2439            NumericOperationIr::CumMax(repr) => {
2440                repr.input.mark_read_only(nodes, &mut output);
2441            }
2442        };
2443
2444        output
2445    }
2446}
2447
2448impl FloatOperationIr {
2449    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2450        match self {
2451            FloatOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2452            FloatOperationIr::Cross(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2453            FloatOperationIr::Random(_repr) => Box::new([].into_iter()),
2454            FloatOperationIr::Exp(repr) => Box::new([&repr.input].into_iter()),
2455            FloatOperationIr::Log(repr) => Box::new([&repr.input].into_iter()),
2456            FloatOperationIr::Log1p(repr) => Box::new([&repr.input].into_iter()),
2457            FloatOperationIr::Erf(repr) => Box::new([&repr.input].into_iter()),
2458            FloatOperationIr::Recip(repr) => Box::new([&repr.input].into_iter()),
2459            FloatOperationIr::PowfScalar(repr) => Box::new([&repr.lhs].into_iter()),
2460            FloatOperationIr::Sqrt(repr) => Box::new([&repr.input].into_iter()),
2461            FloatOperationIr::Cos(repr) => Box::new([&repr.input].into_iter()),
2462            FloatOperationIr::Sin(repr) => Box::new([&repr.input].into_iter()),
2463            FloatOperationIr::Tanh(repr) => Box::new([&repr.input].into_iter()),
2464            FloatOperationIr::Round(repr) => Box::new([&repr.input].into_iter()),
2465            FloatOperationIr::Floor(repr) => Box::new([&repr.input].into_iter()),
2466            FloatOperationIr::Ceil(repr) => Box::new([&repr.input].into_iter()),
2467            FloatOperationIr::Trunc(repr) => Box::new([&repr.input].into_iter()),
2468            FloatOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
2469            FloatOperationIr::Quantize(repr) => {
2470                Box::new([&repr.tensor, &repr.qparams.scales].into_iter())
2471            }
2472            FloatOperationIr::Dequantize(repr) => Box::new([&repr.input].into_iter()),
2473            FloatOperationIr::IsNan(repr) => Box::new([&repr.input].into_iter()),
2474            FloatOperationIr::IsInf(repr) => Box::new([&repr.input].into_iter()),
2475            FloatOperationIr::GridSample2d(repr) => {
2476                Box::new([&repr.tensor, &repr.grid].into_iter())
2477            }
2478            FloatOperationIr::Tan(repr) => Box::new([&repr.input].into_iter()),
2479            FloatOperationIr::Cosh(repr) => Box::new([&repr.input].into_iter()),
2480            FloatOperationIr::Sinh(repr) => Box::new([&repr.input].into_iter()),
2481            FloatOperationIr::ArcCos(repr) => Box::new([&repr.input].into_iter()),
2482            FloatOperationIr::ArcCosh(repr) => Box::new([&repr.input].into_iter()),
2483            FloatOperationIr::ArcSin(repr) => Box::new([&repr.input].into_iter()),
2484            FloatOperationIr::ArcSinh(repr) => Box::new([&repr.input].into_iter()),
2485            FloatOperationIr::ArcTan(repr) => Box::new([&repr.input].into_iter()),
2486            FloatOperationIr::ArcTanh(repr) => Box::new([&repr.input].into_iter()),
2487            FloatOperationIr::ArcTan2(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2488            FloatOperationIr::Powf(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2489        }
2490    }
2491    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2492        match self {
2493            FloatOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
2494            FloatOperationIr::Cross(repr) => Box::new([&repr.out].into_iter()),
2495            FloatOperationIr::Random(repr) => Box::new([&repr.out].into_iter()),
2496            FloatOperationIr::Exp(repr) => Box::new([&repr.out].into_iter()),
2497            FloatOperationIr::Log(repr) => Box::new([&repr.out].into_iter()),
2498            FloatOperationIr::Log1p(repr) => Box::new([&repr.out].into_iter()),
2499            FloatOperationIr::Erf(repr) => Box::new([&repr.out].into_iter()),
2500            FloatOperationIr::Recip(repr) => Box::new([&repr.out].into_iter()),
2501            FloatOperationIr::PowfScalar(repr) => Box::new([&repr.out].into_iter()),
2502            FloatOperationIr::Sqrt(repr) => Box::new([&repr.out].into_iter()),
2503            FloatOperationIr::Cos(repr) => Box::new([&repr.out].into_iter()),
2504            FloatOperationIr::Sin(repr) => Box::new([&repr.out].into_iter()),
2505            FloatOperationIr::Tanh(repr) => Box::new([&repr.out].into_iter()),
2506            FloatOperationIr::Round(repr) => Box::new([&repr.out].into_iter()),
2507            FloatOperationIr::Floor(repr) => Box::new([&repr.out].into_iter()),
2508            FloatOperationIr::Ceil(repr) => Box::new([&repr.out].into_iter()),
2509            FloatOperationIr::Trunc(repr) => Box::new([&repr.out].into_iter()),
2510            FloatOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
2511            FloatOperationIr::Quantize(repr) => Box::new([&repr.out].into_iter()),
2512            FloatOperationIr::Dequantize(repr) => Box::new([&repr.out].into_iter()),
2513            FloatOperationIr::IsNan(repr) => Box::new([&repr.out].into_iter()),
2514            FloatOperationIr::IsInf(repr) => Box::new([&repr.out].into_iter()),
2515            FloatOperationIr::GridSample2d(repr) => Box::new([&repr.out].into_iter()),
2516            FloatOperationIr::Tan(repr) => Box::new([&repr.out].into_iter()),
2517            FloatOperationIr::Cosh(repr) => Box::new([&repr.out].into_iter()),
2518            FloatOperationIr::Sinh(repr) => Box::new([&repr.out].into_iter()),
2519            FloatOperationIr::ArcCos(repr) => Box::new([&repr.out].into_iter()),
2520            FloatOperationIr::ArcCosh(repr) => Box::new([&repr.out].into_iter()),
2521            FloatOperationIr::ArcSin(repr) => Box::new([&repr.out].into_iter()),
2522            FloatOperationIr::ArcSinh(repr) => Box::new([&repr.out].into_iter()),
2523            FloatOperationIr::ArcTan(repr) => Box::new([&repr.out].into_iter()),
2524            FloatOperationIr::ArcTanh(repr) => Box::new([&repr.out].into_iter()),
2525            FloatOperationIr::ArcTan2(repr) => Box::new([&repr.out].into_iter()),
2526            FloatOperationIr::Powf(repr) => Box::new([&repr.out].into_iter()),
2527        }
2528    }
2529
2530    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2531        let mut output = Vec::new();
2532
2533        match self {
2534            FloatOperationIr::Matmul(repr) => {
2535                repr.lhs.mark_read_only(nodes, &mut output);
2536                repr.rhs.mark_read_only(nodes, &mut output);
2537            }
2538            FloatOperationIr::Cross(repr) => {
2539                repr.lhs.mark_read_only(nodes, &mut output);
2540                repr.rhs.mark_read_only(nodes, &mut output);
2541            }
2542            FloatOperationIr::Random(_) => {}
2543            FloatOperationIr::Exp(repr) => {
2544                repr.input.mark_read_only(nodes, &mut output);
2545            }
2546            FloatOperationIr::Log(repr) => {
2547                repr.input.mark_read_only(nodes, &mut output);
2548            }
2549            FloatOperationIr::Log1p(repr) => {
2550                repr.input.mark_read_only(nodes, &mut output);
2551            }
2552            FloatOperationIr::Erf(repr) => {
2553                repr.input.mark_read_only(nodes, &mut output);
2554            }
2555            FloatOperationIr::Recip(repr) => {
2556                repr.input.mark_read_only(nodes, &mut output);
2557            }
2558            FloatOperationIr::PowfScalar(repr) => {
2559                repr.lhs.mark_read_only(nodes, &mut output);
2560            }
2561            FloatOperationIr::Sqrt(repr) => {
2562                repr.input.mark_read_only(nodes, &mut output);
2563            }
2564            FloatOperationIr::Cos(repr) => {
2565                repr.input.mark_read_only(nodes, &mut output);
2566            }
2567            FloatOperationIr::Sin(repr) => {
2568                repr.input.mark_read_only(nodes, &mut output);
2569            }
2570            FloatOperationIr::Tanh(repr) => {
2571                repr.input.mark_read_only(nodes, &mut output);
2572            }
2573            FloatOperationIr::Round(repr) => {
2574                repr.input.mark_read_only(nodes, &mut output);
2575            }
2576            FloatOperationIr::Floor(repr) => {
2577                repr.input.mark_read_only(nodes, &mut output);
2578            }
2579            FloatOperationIr::Ceil(repr) => {
2580                repr.input.mark_read_only(nodes, &mut output);
2581            }
2582            FloatOperationIr::Trunc(repr) => {
2583                repr.input.mark_read_only(nodes, &mut output);
2584            }
2585            FloatOperationIr::Quantize(repr) => {
2586                repr.tensor.mark_read_only(nodes, &mut output);
2587                repr.qparams.scales.mark_read_only(nodes, &mut output);
2588            }
2589            FloatOperationIr::Dequantize(repr) => {
2590                repr.input.mark_read_only(nodes, &mut output);
2591            }
2592            FloatOperationIr::IntoInt(repr) => {
2593                repr.input.mark_read_only(nodes, &mut output);
2594            }
2595            FloatOperationIr::IsNan(repr) => {
2596                repr.input.mark_read_only(nodes, &mut output);
2597            }
2598            FloatOperationIr::IsInf(repr) => {
2599                repr.input.mark_read_only(nodes, &mut output);
2600            }
2601            FloatOperationIr::GridSample2d(repr) => {
2602                repr.tensor.mark_read_only(nodes, &mut output);
2603                repr.grid.mark_read_only(nodes, &mut output);
2604            }
2605            FloatOperationIr::Tan(repr) => repr.input.mark_read_only(nodes, &mut output),
2606            FloatOperationIr::Cosh(repr) => repr.input.mark_read_only(nodes, &mut output),
2607            FloatOperationIr::Sinh(repr) => repr.input.mark_read_only(nodes, &mut output),
2608            FloatOperationIr::ArcCos(repr) => repr.input.mark_read_only(nodes, &mut output),
2609            FloatOperationIr::ArcCosh(repr) => repr.input.mark_read_only(nodes, &mut output),
2610            FloatOperationIr::ArcSin(repr) => repr.input.mark_read_only(nodes, &mut output),
2611            FloatOperationIr::ArcSinh(repr) => repr.input.mark_read_only(nodes, &mut output),
2612            FloatOperationIr::ArcTan(repr) => repr.input.mark_read_only(nodes, &mut output),
2613            FloatOperationIr::ArcTanh(repr) => repr.input.mark_read_only(nodes, &mut output),
2614            FloatOperationIr::ArcTan2(repr) => {
2615                repr.lhs.mark_read_only(nodes, &mut output);
2616                repr.rhs.mark_read_only(nodes, &mut output);
2617            }
2618            FloatOperationIr::Powf(repr) => {
2619                repr.lhs.mark_read_only(nodes, &mut output);
2620                repr.rhs.mark_read_only(nodes, &mut output);
2621            }
2622        };
2623
2624        output
2625    }
2626}
2627
impl IntOperationIr {
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            IntOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
            IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.lhs].into_iter()),
            IntOperationIr::BitwiseOr(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.lhs].into_iter()),
            IntOperationIr::BitwiseXor(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.lhs].into_iter()),
            IntOperationIr::BitwiseNot(repr) => Box::new([&repr.input].into_iter()),
            IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
            IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
        }
    }

    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            IntOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseOr(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseXor(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseNot(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.out].into_iter()),
            IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.out].into_iter()),
        }
    }

    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            IntOperationIr::Matmul(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::IntoFloat(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseAnd(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseAndScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseOr(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseOrScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseXor(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseXorScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseNot(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseLeftShift(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseLeftShiftScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseRightShift(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseRightShiftScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}

impl BoolOperationIr {
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            BoolOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
            BoolOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
            BoolOperationIr::Not(repr) => Box::new([&repr.input].into_iter()),
            BoolOperationIr::And(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
            BoolOperationIr::Or(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
        }
    }

    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            BoolOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
            BoolOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
            BoolOperationIr::Not(repr) => Box::new([&repr.out].into_iter()),
            BoolOperationIr::And(repr) => Box::new([&repr.out].into_iter()),
            BoolOperationIr::Or(repr) => Box::new([&repr.out].into_iter()),
        }
    }

    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            BoolOperationIr::IntoFloat(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BoolOperationIr::IntoInt(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BoolOperationIr::Not(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BoolOperationIr::And(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            BoolOperationIr::Or(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}

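// Module operations may carry optional tensors (`bias`, `mask`, `attn_bias`),
// so the accessors below branch on those `Option`s and only report tensors
// that are actually present.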
impl ModuleOperationIr {
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            ModuleOperationIr::Embedding(repr) => {
                Box::new([&repr.weights, &repr.indices].into_iter())
            }
            ModuleOperationIr::EmbeddingBackward(repr) => {
                Box::new([&repr.weights, &repr.out_grad, &repr.indices].into_iter())
            }
            ModuleOperationIr::Linear(repr) => {
                if let Some(bias) = &repr.bias {
                    Box::new([&repr.x, &repr.weight, bias].into_iter())
                } else {
                    Box::new([&repr.x, &repr.weight].into_iter())
                }
            }
            ModuleOperationIr::LinearXBackward(repr) => {
                Box::new([&repr.weight, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::LinearWeightBackward(repr) => {
                Box::new([&repr.x, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::LinearBiasBackward(repr) => {
                Box::new([&repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv1d(repr) => {
                if let Some(bias) = &repr.bias {
                    Box::new([&repr.x, &repr.weight, bias].into_iter())
                } else {
                    Box::new([&repr.x, &repr.weight].into_iter())
                }
            }
            ModuleOperationIr::Conv1dXBackward(repr) => {
                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv1dWeightBackward(repr) => {
                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv1dBiasBackward(repr) => {
                Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv2d(repr) => {
                if let Some(bias) = &repr.bias {
                    Box::new([&repr.x, &repr.weight, bias].into_iter())
                } else {
                    Box::new([&repr.x, &repr.weight].into_iter())
                }
            }
            ModuleOperationIr::Conv2dXBackward(repr) => {
                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv2dWeightBackward(repr) => {
                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv2dBiasBackward(repr) => {
                Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv3d(repr) => {
                if let Some(bias) = &repr.bias {
                    Box::new([&repr.x, &repr.weight, bias].into_iter())
                } else {
                    Box::new([&repr.x, &repr.weight].into_iter())
                }
            }
            ModuleOperationIr::Conv3dXBackward(repr) => {
                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv3dWeightBackward(repr) => {
                Box::new([&repr.x, &repr.weight, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::Conv3dBiasBackward(repr) => {
                Box::new([&repr.x, &repr.bias, &repr.output_grad].into_iter())
            }
            ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) {
                (Some(mask), Some(bias)) => {
                    Box::new([&repr.x, &repr.offset, &repr.weight, mask, bias].into_iter())
                }
                (Some(mask), None) => {
                    Box::new([&repr.x, &repr.offset, &repr.weight, mask].into_iter())
                }
                (None, Some(bias)) => {
                    Box::new([&repr.x, &repr.offset, &repr.weight, bias].into_iter())
                }
                (None, None) => Box::new([&repr.x, &repr.offset, &repr.weight].into_iter()),
            },
            ModuleOperationIr::DeformableConv2dBackward(repr) => match (&repr.mask, &repr.bias) {
                (Some(mask), Some(bias)) => Box::new(
                    [
                        &repr.x,
                        &repr.offset,
                        &repr.weight,
                        &repr.out_grad,
                        mask,
                        bias,
                    ]
                    .into_iter(),
                ),
                (Some(mask), None) => Box::new(
                    [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, mask].into_iter(),
                ),
                (None, Some(bias)) => Box::new(
                    [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, bias].into_iter(),
                ),
                (None, None) => {
                    Box::new([&repr.x, &repr.offset, &repr.weight, &repr.out_grad].into_iter())
                }
            },
            ModuleOperationIr::ConvTranspose1d(repr) => {
                if let Some(bias) = &repr.bias {
                    Box::new([&repr.x, &repr.weight, bias].into_iter())
                } else {
                    Box::new([&repr.x, &repr.weight].into_iter())
                }
            }
            ModuleOperationIr::ConvTranspose2d(repr) => {
                if let Some(bias) = &repr.bias {
                    Box::new([&repr.x, &repr.weight, bias].into_iter())
                } else {
                    Box::new([&repr.x, &repr.weight].into_iter())
                }
            }
            ModuleOperationIr::ConvTranspose3d(repr) => {
                if let Some(bias) = &repr.bias {
                    Box::new([&repr.x, &repr.weight, bias].into_iter())
                } else {
                    Box::new([&repr.x, &repr.weight].into_iter())
                }
            }
            ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::AvgPool1dBackward(repr) => {
                Box::new([&repr.x, &repr.grad].into_iter())
            }
            ModuleOperationIr::AvgPool2dBackward(repr) => {
                Box::new([&repr.x, &repr.grad].into_iter())
            }
            ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
                Box::new([&repr.x, &repr.grad].into_iter())
            }
            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
                Box::new([&repr.x, &repr.grad].into_iter())
            }
            ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::MaxPool1dWithIndices(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
                Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
            }
            ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::MaxPool2dWithIndices(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
                Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
            }
            ModuleOperationIr::Interpolate(repr) => Box::new([&repr.x].into_iter()),
            ModuleOperationIr::InterpolateBackward(repr) => {
                Box::new([&repr.x, &repr.grad].into_iter())
            }
            ModuleOperationIr::Rfft(repr) => Box::new([&repr.signal].into_iter()),
            ModuleOperationIr::IRfft(repr) => {
                Box::new([&repr.input_re, &repr.input_im].into_iter())
            }
            ModuleOperationIr::Attention(repr) => {
                if let Some(mask) = &repr.mask {
                    if let Some(attn_bias) = &repr.attn_bias {
                        Box::new([&repr.query, &repr.key, &repr.value, mask, attn_bias].into_iter())
                    } else {
                        Box::new([&repr.query, &repr.key, &repr.value, mask].into_iter())
                    }
                } else if let Some(attn_bias) = &repr.attn_bias {
                    Box::new([&repr.query, &repr.key, &repr.value, attn_bias].into_iter())
                } else {
                    Box::new([&repr.query, &repr.key, &repr.value].into_iter())
                }
            }
            ModuleOperationIr::CtcLoss(repr) => Box::new(
                [
                    &repr.log_probs,
                    &repr.targets,
                    &repr.input_lengths,
                    &repr.target_lengths,
                ]
                .into_iter(),
            ),
            ModuleOperationIr::CtcLossBackward(repr) => Box::new(
                [
                    &repr.log_probs,
                    &repr.targets,
                    &repr.input_lengths,
                    &repr.target_lengths,
                    &repr.grad_loss,
                ]
                .into_iter(),
            ),
        }
    }

    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            ModuleOperationIr::Embedding(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::EmbeddingBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Linear(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::LinearXBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::LinearWeightBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::LinearBiasBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv1d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv1dXBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv1dWeightBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv1dBiasBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv2d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv2dXBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv2dWeightBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv2dBiasBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv3d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv3dXBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv3dWeightBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Conv3dBiasBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::DeformableConv2d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::DeformableConv2dBackward(repr) => {
                match (&repr.mask_grad, &repr.bias_grad) {
                    (Some(mask_grad), Some(bias_grad)) => Box::new(
                        [
                            &repr.input_grad,
                            &repr.offset_grad,
                            &repr.weight_grad,
                            mask_grad,
                            bias_grad,
                        ]
                        .into_iter(),
                    ),
                    (Some(mask_grad), None) => Box::new(
                        [
                            &repr.input_grad,
                            &repr.offset_grad,
                            &repr.weight_grad,
                            mask_grad,
                        ]
                        .into_iter(),
                    ),
                    (None, Some(bias_grad)) => Box::new(
                        [
                            &repr.input_grad,
                            &repr.offset_grad,
                            &repr.weight_grad,
                            bias_grad,
                        ]
                        .into_iter(),
                    ),
                    (None, None) => Box::new(
                        [&repr.input_grad, &repr.offset_grad, &repr.weight_grad].into_iter(),
                    ),
                }
            }
            ModuleOperationIr::ConvTranspose1d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::ConvTranspose2d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::ConvTranspose3d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::AvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::AvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::MaxPool1dWithIndices(repr) => {
                Box::new([&repr.out, &repr.out_indices].into_iter())
            }
            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
                Box::new([&repr.out].into_iter())
            }
            ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::MaxPool2dWithIndices(repr) => {
                Box::new([&repr.out, &repr.out_indices].into_iter())
            }
            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
                Box::new([&repr.out].into_iter())
            }
            ModuleOperationIr::Interpolate(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::InterpolateBackward(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::Rfft(repr) => Box::new([&repr.out_re, &repr.out_im].into_iter()),
            ModuleOperationIr::IRfft(repr) => Box::new([&repr.out_signal].into_iter()),
            ModuleOperationIr::Attention(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::CtcLoss(repr) => Box::new([&repr.out].into_iter()),
            ModuleOperationIr::CtcLossBackward(repr) => Box::new([&repr.out].into_iter()),
        }
    }

    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            ModuleOperationIr::Embedding(repr) => {
                repr.weights.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::EmbeddingBackward(repr) => {
                repr.weights.mark_read_only(nodes, &mut output);
                repr.out_grad.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Linear(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);

                if let Some(bias) = &mut repr.bias {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::LinearXBackward(repr) => {
                repr.weight.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::LinearWeightBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::LinearBiasBackward(repr) => {
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv1d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);

                if let Some(bias) = &mut repr.bias {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::Conv1dXBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv1dWeightBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv1dBiasBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.bias.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv2d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);

                if let Some(bias) = &mut repr.bias {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::Conv2dXBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv2dWeightBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv2dBiasBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.bias.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv3d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);

                if let Some(bias) = &mut repr.bias {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::Conv3dXBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv3dWeightBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Conv3dBiasBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.bias.mark_read_only(nodes, &mut output);
                repr.output_grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::DeformableConv2d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);
                repr.offset.mark_read_only(nodes, &mut output);

                if let Some(mask) = repr.mask.as_mut() {
                    mask.mark_read_only(nodes, &mut output);
                }
                if let Some(bias) = repr.bias.as_mut() {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::DeformableConv2dBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);
                repr.offset.mark_read_only(nodes, &mut output);
                repr.out_grad.mark_read_only(nodes, &mut output);

                if let Some(mask) = repr.mask.as_mut() {
                    mask.mark_read_only(nodes, &mut output);
                }
                if let Some(bias) = repr.bias.as_mut() {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::ConvTranspose1d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);

                if let Some(bias) = &mut repr.bias {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::ConvTranspose2d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);

                if let Some(bias) = &mut repr.bias {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::ConvTranspose3d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.weight.mark_read_only(nodes, &mut output);

                if let Some(bias) = &mut repr.bias {
                    bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::AvgPool1d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::AvgPool2d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::AvgPool1dBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::AvgPool2dBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::MaxPool1d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::MaxPool1dWithIndices(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::MaxPool2d(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::MaxPool2dWithIndices(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Interpolate(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::InterpolateBackward(repr) => {
                repr.x.mark_read_only(nodes, &mut output);
                repr.grad.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Rfft(repr) => {
                repr.signal.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::IRfft(repr) => {
                repr.input_re.mark_read_only(nodes, &mut output);
                repr.input_im.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::Attention(repr) => {
                repr.query.mark_read_only(nodes, &mut output);
                repr.key.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
                if let Some(mask) = &mut repr.mask {
                    mask.mark_read_only(nodes, &mut output);
                }
                if let Some(attn_bias) = &mut repr.attn_bias {
                    attn_bias.mark_read_only(nodes, &mut output);
                }
            }
            ModuleOperationIr::CtcLoss(repr) => {
                repr.log_probs.mark_read_only(nodes, &mut output);
                repr.targets.mark_read_only(nodes, &mut output);
                repr.input_lengths.mark_read_only(nodes, &mut output);
                repr.target_lengths.mark_read_only(nodes, &mut output);
            }
            ModuleOperationIr::CtcLossBackward(repr) => {
                repr.log_probs.mark_read_only(nodes, &mut output);
                repr.targets.mark_read_only(nodes, &mut output);
                repr.input_lengths.mark_read_only(nodes, &mut output);
                repr.target_lengths.mark_read_only(nodes, &mut output);
                repr.grad_loss.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}

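// Collective operations; all-reduce is the only variant handled here.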
#[cfg(feature = "distributed")]
impl DistributedOperationIr {
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            DistributedOperationIr::AllReduce(repr) => Box::new([&repr.tensor].into_iter()),
        }
    }

    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        match self {
            DistributedOperationIr::AllReduce(repr) => Box::new([&repr.out].into_iter()),
        }
    }

    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            DistributedOperationIr::AllReduce(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
        }

        output
    }
}

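// An init operation materializes its output from scratch, so it has no tensor
// inputs.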
impl InitOperationIr {
    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        Box::new([].into_iter())
    }

    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
        Box::new([&self.out].into_iter())
    }
}

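// A tensor only transitions when it is currently read-write *and* appears in
// `nodes`; the pre-transition descriptor is pushed to `output` so callers know
// exactly which tensors were affected.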
impl TensorIr {
    fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec<TensorIr>) {
        if self.status == TensorStatus::ReadWrite && nodes.contains(&self.id) {
            output.push(self.clone());
            self.status = TensorStatus::ReadOnly;
        }
    }
}

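// Manual `Hash` implementation: the distribution parameters are floats, which
// do not implement `Hash`, so only the output tensor and the distribution's
// variant tag contribute to the hash.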
impl core::hash::Hash for RandomOpIr {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        self.out.hash(state);

        match self.distribution {
            Distribution::Default => 1u8.hash(state),
            Distribution::Bernoulli(_) => 2u8.hash(state),
            Distribution::Uniform(_, _) => 3u8.hash(state),
            Distribution::Normal(_, _) => 4u8.hash(state),
        }
    }
}

/// Extension trait to extract outputs when registering an operation.
pub trait OperationOutput<O> {
    /// Extract a single output.
    fn output(self) -> O;

    /// Extract a fixed number of outputs.
    fn outputs<const N: usize>(self) -> [O; N];
}

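// The `Debug` bound below is required because converting a `Vec<O>` into a
// fixed-size array returns the original vector as the error value, and
// `unwrap` needs `Debug` to print it when the arity does not match.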
impl<O: core::fmt::Debug> OperationOutput<O> for Vec<O> {
    fn output(self) -> O {
        let [tensor] = self.outputs();
        tensor
    }

    fn outputs<const N: usize>(self) -> [O; N] {
        self.try_into().unwrap()
    }
}
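
// A minimal usage sketch (illustrative only; `outs` stands in for the vector
// returned when registering an operation):
//
//     let outs: Vec<TensorIr> = /* returned by the runtime */;
//     let [out_re, out_im]: [TensorIr; 2] = outs.outputs();
//
// or, when a single output is expected:
//
//     let out: TensorIr = outs.output();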