Skip to main content

burn_ir/
operation.rs

1use burn_backend::tensor::IndexingUpdateOp;
2use core::hash::Hash;
3use serde::{Deserialize, Serialize};
4
5use alloc::borrow::ToOwned;
6use alloc::boxed::Box;
7use alloc::{string::String, vec::Vec};
8
9use burn_backend::{
10    DType, Distribution, Slice,
11    ops::{
12        ConvOptions, ConvTransposeOptions, DeformConvOptions, GridSampleOptions,
13        GridSamplePaddingMode, InterpolateMode, InterpolateOptions,
14    },
15    quantization::QuantScheme,
16};
17
18use crate::{ScalarIr, TensorId, TensorIr, TensorStatus};
19
20/// Custom operation in a fusion stream, declaring its inputs and outputs.
21#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
22pub struct CustomOpIr {
23    /// Unique identifier of the operation.
24    pub id: String,
25    /// Input tensors declared by the custom operation.
26    pub inputs: Vec<TensorIr>,
27    /// Output tensors declared by the custom operation.
28    pub outputs: Vec<TensorIr>,
29}
30
31impl CustomOpIr {
32    /// Create a new custom operation intermediate representation.
33    pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self {
34        Self {
35            id: id.to_owned(),
36            inputs: inputs.to_vec(),
37            outputs: outputs.to_vec(),
38        }
39    }
40
41    /// Cast the intermediate representation, and get the in and output tensors.
42    pub fn as_fixed<const N_IN: usize, const N_OUT: usize>(
43        &self,
44    ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) {
45        (
46            self.inputs.as_slice().try_into().expect(
47                "Wrong number of inputs expected (expected {D}, is {}), check your implementation",
48            ),
49            self.outputs.as_slice().try_into().expect(
50                "Wrong number of outputs expected (expected {D}, is {}), check your implementation",
51            ),
52        )
53    }
54
55    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
56        Box::new(self.inputs.iter())
57    }
58
59    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
60        Box::new(self.outputs.iter())
61    }
62}
63
64/// Describes all possible tensor operations.
65#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
66pub enum OperationIr {
67    /// Basic operation on a float tensor.
68    BaseFloat(BaseOperationIr),
69    /// Basic operation on an int tensor.
70    BaseInt(BaseOperationIr),
71    /// Basic operation on a bool tensor.
72    BaseBool(BaseOperationIr),
73    /// Numeric operation on a float tensor, paired with a [`DType`].
74    NumericFloat(DType, NumericOperationIr),
75    /// Numeric operation on an int tensor, paired with a [`DType`].
76    NumericInt(DType, NumericOperationIr),
77    /// Operation specific to a bool tensor.
78    Bool(BoolOperationIr),
79    /// Operation specific to an int tensor.
80    Int(IntOperationIr),
81    /// Operation specific to a float tensor, paired with a [`DType`].
82    Float(DType, FloatOperationIr),
83    /// Module operation.
84    Module(ModuleOperationIr),
85    /// Initialize operation.
86    Init(InitOperationIr),
87    /// A custom operation.
88    Custom(CustomOpIr),
89    /// A tensor is dropped.
90    Drop(TensorIr),
91}
92
93/// Operation intermediate representation specific to a float tensor.
94#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
95pub enum FloatOperationIr {
96    /// Operation corresponding to [exp](burn_backend::ops::FloatTensorOps::float_exp).
97    Exp(UnaryOpIr),
98    /// Operation corresponding to [log](burn_backend::ops::FloatTensorOps::float_log).
99    Log(UnaryOpIr),
100    /// Operation corresponding to [log1p](burn_backend::ops::FloatTensorOps::float_log1p).
101    Log1p(UnaryOpIr),
102    /// Operation corresponding to [erf](burn_backend::ops::FloatTensorOps::float_erf).
103    Erf(UnaryOpIr),
104    /// Operation corresponding to [powf_scalar](burn_backend::ops::FloatTensorOps::float_powf_scalar).
105    PowfScalar(ScalarOpIr),
106    /// Operation corresponding to [sqrt](burn_backend::ops::FloatTensorOps::float_sqrt).
107    Sqrt(UnaryOpIr),
108    /// Operation corresponding to [cos](burn_backend::ops::FloatTensorOps::float_cos).
109    Cos(UnaryOpIr),
110    /// Operation corresponding to [cosh](burn_backend::ops::FloatTensorOps::float_cosh).
111    Cosh(UnaryOpIr),
112    /// Operation corresponding to [sin](burn_backend::ops::FloatTensorOps::float_sin).
113    Sin(UnaryOpIr),
114    /// Operation corresponding to [sinh](burn_backend::ops::FloatTensorOps::float_sinh).
115    Sinh(UnaryOpIr),
116    /// Operation corresponding to [tan](burn_backend::ops::FloatTensorOps::float_tan).
117    Tan(UnaryOpIr),
118    /// Operation corresponding to [tanh](burn_backend::ops::FloatTensorOps::float_tanh).
119    Tanh(UnaryOpIr),
120    /// Operation corresponding to [acos](burn_backend::ops::FloatTensorOps::float_acos).
121    ArcCos(UnaryOpIr),
122    /// Operation corresponding to [acosh](burn_backend::ops::FloatTensorOps::float_acosh).
123    ArcCosh(UnaryOpIr),
124    /// Operation corresponding to [asin](burn_backend::ops::FloatTensorOps::float_asin).
125    ArcSin(UnaryOpIr),
126    /// Operation corresponding to [asinh](burn_backend::ops::FloatTensorOps::float_asinh).
127    ArcSinh(UnaryOpIr),
128    /// Operation corresponding to [atan](burn_backend::ops::FloatTensorOps::float_atan).
129    ArcTan(UnaryOpIr),
130    /// Operation corresponding to [atanh](burn_backend::ops::FloatTensorOps::float_atanh).
131    ArcTanh(UnaryOpIr),
132    /// Operation corresponding to [atan2](burn_backend::ops::FloatTensorOps::float_atan2).
133    ArcTan2(BinaryOpIr),
134    /// Operation corresponding to [round](burn_backend::ops::FloatTensorOps::float_round).
135    Round(UnaryOpIr),
136    /// Operation corresponding to [floor](burn_backend::ops::FloatTensorOps::float_floor).
137    Floor(UnaryOpIr),
138    /// Operation corresponding to [ceil](burn_backend::ops::FloatTensorOps::float_ceil).
139    Ceil(UnaryOpIr),
140    /// Operation corresponding to [trunc](burn_backend::ops::FloatTensorOps::float_trunc).
141    Trunc(UnaryOpIr),
142    /// Operation corresponding to [into_int](burn_backend::ops::FloatTensorOps::float_into_int).
143    IntoInt(CastOpIr),
144    /// Operation corresponding to [matmul](burn_backend::ops::FloatTensorOps::float_matmul).
145    Matmul(MatmulOpIr),
146    /// Operation corresponding to [cross](burn_backend::ops::FloatTensorOps::float_cross).
147    Cross(CrossOpIr),
148    /// Operation corresponding to [random](burn_backend::ops::FloatTensorOps::float_random).
149    Random(RandomOpIr),
150    /// Operation corresponding to [recip](burn_backend::ops::FloatTensorOps::float_recip).
151    Recip(UnaryOpIr),
152    /// Operation corresponding to [is_nan](burn_backend::ops::FloatTensorOps::float_is_nan).
153    IsNan(UnaryOpIr),
154    /// Operation corresponding to [is_inf](burn_backend::ops::FloatTensorOps::float_is_inf).
155    IsInf(UnaryOpIr),
156    /// Operation corresponding to [quantize](burn_backend::ops::QTensorOps::quantize).
157    Quantize(QuantizeOpIr),
158    /// Operation corresponding to [dequantize](burn_backend::ops::QTensorOps::dequantize).
159    Dequantize(DequantizeOpIr),
160    /// Operation corresponding to [grid_sample_2d](burn_backend::ops::FloatTensorOps::float_grid_sample_2d).
161    GridSample2d(GridSample2dOpIr),
162}
163
164/// Operation intermediate representation specific to module.
165#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
166pub enum ModuleOperationIr {
167    /// Operation corresponding to [embedding](burn_backend::ops::ModuleOps::embedding).
168    Embedding(EmbeddingOpIr),
169    /// Operation corresponding to [embedding_backward](burn_backend::ops::ModuleOps::embedding_backward).
170    EmbeddingBackward(EmbeddingBackwardOpIr),
171    /// Operation corresponding to [conv1d](burn_backend::ops::ModuleOps::conv1d).
172    Conv1d(Conv1dOpIr),
173    /// Operation corresponding to [conv2d](burn_backend::ops::ModuleOps::conv2d).
174    Conv2d(Conv2dOpIr),
175    /// Operation corresponding to [conv3d](burn_backend::ops::ModuleOps::conv3d).
176    Conv3d(Conv3dOpIr),
177    /// Operation corresponding to [deform_conv2d](burn_backend::ops::ModuleOps::deform_conv2d).
178    DeformableConv2d(Box<DeformConv2dOpIr>),
179    /// Operation corresponding to [deform_conv2d_backward](burn_backend::ops::ModuleOps::deform_conv2d_backward).
180    DeformableConv2dBackward(Box<DeformConv2dBackwardOpIr>),
181    /// Operation corresponding to [conv transpose 1d](burn_backend::ops::ModuleOps::conv_transpose1d).
182    ConvTranspose1d(ConvTranspose1dOpIr),
183    /// Operation corresponding to [conv transpose 2d](burn_backend::ops::ModuleOps::conv_transpose2d).
184    ConvTranspose2d(ConvTranspose2dOpIr),
185    /// Operation corresponding to [conv transpose 3d](burn_backend::ops::ModuleOps::conv_transpose3d).
186    ConvTranspose3d(ConvTranspose3dOpIr),
187    /// Operation corresponding to [avg pool 1d](burn_backend::ops::ModuleOps::avg_pool1d).
188    AvgPool1d(AvgPool1dOpIr),
189    /// Operation corresponding to [avg pool 2d](burn_backend::ops::ModuleOps::avg_pool2d).
190    AvgPool2d(AvgPool2dOpIr),
191    /// Operation corresponding to
192    /// [avg pool 1d backward](burn_backend::ops::ModuleOps::avg_pool1d_backward).
193    AvgPool1dBackward(AvgPool1dBackwardOpIr),
194    /// Operation corresponding to
195    /// [avg pool 2d backward](burn_backend::ops::ModuleOps::avg_pool2d_backward).
196    AvgPool2dBackward(AvgPool2dBackwardOpIr),
197    /// Operation corresponding to
198    /// [adaptive avg pool 1d](burn_backend::ops::ModuleOps::adaptive_avg_pool1d).
199    AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr),
200    /// Operation corresponding to
201    /// [adaptive avg pool 2d](burn_backend::ops::ModuleOps::adaptive_avg_pool2d).
202    AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr),
203    /// Operation corresponding to
204    /// [adaptive avg pool 1d backward](burn_backend::ops::ModuleOps::adaptive_avg_pool1d_backward).
205    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr),
206    /// Operation corresponding to
207    /// [adaptive avg pool 2d backward](burn_backend::ops::ModuleOps::adaptive_avg_pool2d_backward).
208    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr),
209    /// Operation corresponding to
210    /// [max pool 1d](burn_backend::ops::ModuleOps::max_pool1d).
211    MaxPool1d(MaxPool1dOpIr),
212    /// Operation corresponding to
213    /// [max pool 1d with indices](burn_backend::ops::ModuleOps::max_pool1d_with_indices).
214    MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr),
215    /// Operation corresponding to
216    /// [max pool 1d with indices backward](burn_backend::ops::ModuleOps::max_pool1d_with_indices_backward).
217    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr),
218    /// Operation corresponding to
219    /// [max pool 2d](burn_backend::ops::ModuleOps::max_pool2d).
220    MaxPool2d(MaxPool2dOpIr),
221    /// Operation corresponding to
222    /// [max pool 2d with indices](burn_backend::ops::ModuleOps::max_pool2d_with_indices).
223    MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr),
224    /// Operation corresponding to
225    /// [max pool 2d with indices backward](burn_backend::ops::ModuleOps::max_pool2d_with_indices_backward).
226    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr),
227    /// Operation corresponding to [interpolate](burn_backend::ops::ModuleOps::interpolate).
228    Interpolate(InterpolateOpIr),
229    /// Operation corresponding to [interpolate backward](burn_backend::ops::ModuleOps::interpolate_backward).
230    InterpolateBackward(InterpolateBackwardOpIr),
231}
232
233/// Basic operations that can be done on any tensor type.
234#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
235pub enum BaseOperationIr {
236    /// Operation corresponding to:
237    ///
238    /// Float => [reshape](burn_backend::ops::FloatTensorOps::float_reshape).
239    /// Int => [reshape](burn_backend::ops::IntTensorOps::int_reshape).
240    /// Bool => [reshape](burn_backend::ops::BoolTensorOps::bool_reshape).
241    Reshape(ShapeOpIr),
242
243    /// Operation corresponding to:
244    ///
245    /// Float => [swap_dims](burn_backend::ops::FloatTensorOps::float_swap_dims).
246    /// Int => [swap_dims](burn_backend::ops::IntTensorOps::int_swap_dims).
247    /// Bool => [swap_dims](burn_backend::ops::BoolTensorOps::bool_swap_dims).
248    SwapDims(SwapDimsOpIr),
249
250    /// Operation corresponding to:
251    ///
252    /// Float => [permute](burn_backend::ops::FloatTensorOps::float_permute).
253    /// Int => [permute](burn_backend::ops::IntTensorOps::int_permute).
254    /// Bool => [permute](burn_backend::ops::BoolTensorOps::bool_permute).
255    Permute(PermuteOpIr),
256
257    /// Operation corresponding to:
258    /// Float => [flip](burn_backend::ops::FloatTensorOps::float_flip).
259    /// Int => [flip](burn_backend::ops::IntTensorOps::int_flip).
260    /// Bool => [flip](burn_backend::ops::BoolTensorOps::bool_flip).
261    Flip(FlipOpIr),
262
263    /// Operation corresponding to:
264    ///
265    /// Float => [expand](burn_backend::ops::FloatTensorOps::float_expand).
266    /// Int => [expand](burn_backend::ops::IntTensorOps::int_expand).
267    /// Bool => [expand](burn_backend::ops::BoolTensorOps::bool_expand).
268    Expand(ShapeOpIr),
269
270    /// Unfold windows along an axis.
271    /// The axis, window size and step are carried by [`UnfoldOpIr`].
272    Unfold(UnfoldOpIr),
273
274    /// Operation corresponding to:
275    ///
276    /// Float => [slice](burn_backend::ops::FloatTensorOps::float_slice).
277    /// Int => [slice](burn_backend::ops::IntTensorOps::int_slice).
278    /// Bool => [slice](burn_backend::ops::BoolTensorOps::bool_slice).
279    Slice(SliceOpIr),
280    /// Operation corresponding to:
281    ///
282    /// Float => [slice assign](burn_backend::ops::FloatTensorOps::float_slice_assign).
283    /// Int => [slice assign](burn_backend::ops::IntTensorOps::int_slice_assign).
284    /// Bool => [slice assign](burn_backend::ops::BoolTensorOps::bool_slice_assign).
285    SliceAssign(SliceAssignOpIr),
286    /// Operation corresponding to:
287    ///
288    /// Float => [select](burn_backend::ops::FloatTensorOps::float_select).
289    /// Int => [select](burn_backend::ops::IntTensorOps::int_select).
290    /// Bool => [select](burn_backend::ops::BoolTensorOps::bool_select).
291    Select(SelectOpIr),
292    /// Operation corresponding to:
293    ///
294    /// Float => [select assign](burn_backend::ops::FloatTensorOps::float_select_add).
295    /// Int => [select assign](burn_backend::ops::IntTensorOps::int_select_add).
296    /// Bool => [select assign](burn_backend::ops::BoolTensorOps::bool_select_or).
297    SelectAssign(SelectAssignOpIr),
298    /// Operation corresponding to:
299    ///
300    /// Float => [mask where](burn_backend::ops::FloatTensorOps::float_mask_where).
301    /// Int => [mask where](burn_backend::ops::IntTensorOps::int_mask_where).
302    /// Bool => [mask where](burn_backend::ops::BoolTensorOps::bool_mask_where).
303    MaskWhere(MaskWhereOpIr),
304    /// Operation corresponding to:
305    ///
306    /// Float => [mask fill](burn_backend::ops::FloatTensorOps::float_mask_fill).
307    /// Int => [mask fill](burn_backend::ops::IntTensorOps::int_mask_fill).
308    /// Bool => [mask fill](burn_backend::ops::BoolTensorOps::bool_mask_fill).
309    MaskFill(MaskFillOpIr),
310    /// Operation corresponding to:
311    ///
312    /// Float => [gather](burn_backend::ops::FloatTensorOps::float_gather).
313    /// Int => [gather](burn_backend::ops::IntTensorOps::int_gather).
314    /// Bool => [gather](burn_backend::ops::BoolTensorOps::bool_gather).
315    Gather(GatherOpIr),
316    /// Operation corresponding to:
317    ///
318    /// Float => [scatter](burn_backend::ops::FloatTensorOps::float_scatter_add).
319    /// Int => [scatter](burn_backend::ops::IntTensorOps::int_scatter_add).
320    /// Bool => [scatter](burn_backend::ops::BoolTensorOps::bool_scatter_or).
321    Scatter(ScatterOpIr),
322    /// Operation corresponding to:
323    ///
324    /// Float => [equal](burn_backend::ops::FloatTensorOps::float_equal).
325    /// Int => [equal](burn_backend::ops::IntTensorOps::int_equal).
326    /// Bool => [equal](burn_backend::ops::BoolTensorOps::bool_equal).
327    Equal(BinaryOpIr),
328    /// Operation corresponding to:
329    ///
330    /// Float => [equal elem](burn_backend::ops::FloatTensorOps::float_equal_elem).
331    /// Int => [equal elem](burn_backend::ops::IntTensorOps::int_equal_elem).
332    /// Bool => [equal elem](burn_backend::ops::BoolTensorOps::bool_equal_elem).
333    EqualElem(ScalarOpIr),
334    /// Operation corresponding to:
335    ///
336    /// Float => [repeat dim](burn_backend::ops::FloatTensorOps::float_repeat_dim).
337    /// Int => [repeat dim](burn_backend::ops::IntTensorOps::int_repeat_dim).
338    /// Bool => [repeat dim](burn_backend::ops::BoolTensorOps::bool_repeat_dim).
339    RepeatDim(RepeatDimOpIr),
340    /// Operation corresponding to:
341    ///
342    /// Float => [cat](burn_backend::ops::FloatTensorOps::float_cat).
343    /// Int => [cat](burn_backend::ops::IntTensorOps::int_cat).
344    /// Bool => [cat](burn_backend::ops::BoolTensorOps::bool_cat).
345    Cat(CatOpIr),
346    /// Cast operation, no direct operation and should be supported by fusion backend.
347    Cast(CastOpIr),
348    /// Operation corresponding to:
349    ///
350    /// Float => [empty](burn_backend::ops::FloatTensorOps::float_empty).
351    /// Int => [empty](burn_backend::ops::IntTensorOps::int_empty).
352    /// Bool => [empty](burn_backend::ops::BoolTensorOps::bool_empty).
353    Empty(CreationOpIr),
354    /// Operation corresponding to:
355    ///
356    /// Float => [ones](burn_backend::ops::FloatTensorOps::float_ones).
357    /// Int => [ones](burn_backend::ops::IntTensorOps::int_ones).
358    /// Bool => [ones](burn_backend::ops::BoolTensorOps::bool_ones).
359    Ones(CreationOpIr),
360    /// Operation corresponding to:
361    ///
362    /// Float => [zeros](burn_backend::ops::FloatTensorOps::float_zeros).
363    /// Int => [zeros](burn_backend::ops::IntTensorOps::int_zeros).
364    /// Bool => [zeros](burn_backend::ops::BoolTensorOps::bool_zeros).
365    Zeros(CreationOpIr),
366}
367
368/// Numeric operations on int and float tensors.
369#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
370pub enum NumericOperationIr {
371    /// Operation corresponding to:
372    ///
373    /// Float => [add](burn_backend::ops::FloatTensorOps::float_add).
374    /// Int => [add](burn_backend::ops::IntTensorOps::int_add).
375    Add(BinaryOpIr),
376    /// Operation corresponding to:
377    ///
378    /// Float => [add scalar](burn_backend::ops::FloatTensorOps::float_add_scalar).
379    /// Int => [add scalar](burn_backend::ops::IntTensorOps::int_add_scalar).
380    AddScalar(ScalarOpIr),
381    /// Operation corresponding to:
382    ///
383    /// Float => [sub](burn_backend::ops::FloatTensorOps::float_sub).
384    /// Int => [sub](burn_backend::ops::IntTensorOps::int_sub).
385    Sub(BinaryOpIr),
386    /// Operation corresponding to:
387    ///
388    /// Float => [sub scalar](burn_backend::ops::FloatTensorOps::float_sub_scalar).
389    /// Int => [sub scalar](burn_backend::ops::IntTensorOps::int_sub_scalar).
390    SubScalar(ScalarOpIr),
391    /// Operation corresponding to:
392    ///
393    /// Float => [div](burn_backend::ops::FloatTensorOps::float_div).
394    /// Int => [div](burn_backend::ops::IntTensorOps::int_div).
395    Div(BinaryOpIr),
396    /// Operation corresponding to:
397    ///
398    /// Float => [div scalar](burn_backend::ops::FloatTensorOps::float_div_scalar).
399    /// Int => [div scalar](burn_backend::ops::IntTensorOps::int_div_scalar).
400    DivScalar(ScalarOpIr),
401    /// Operation corresponding to:
402    ///
403    /// Float => [rem](burn_backend::ops::FloatTensorOps::float_remainder).
404    /// Int => [rem](burn_backend::ops::IntTensorOps::int_remainder).
405    Rem(BinaryOpIr),
406    /// Operation corresponding to:
407    ///
408    /// Float => [rem scalar](burn_backend::ops::FloatTensorOps::float_remainder_scalar).
409    /// Int => [rem scalar](burn_backend::ops::IntTensorOps::int_remainder_scalar).
410    RemScalar(ScalarOpIr),
411    /// Operation corresponding to:
412    ///
413    /// Float => [mul](burn_backend::ops::FloatTensorOps::float_mul).
414    /// Int => [mul](burn_backend::ops::IntTensorOps::int_mul).
415    Mul(BinaryOpIr),
416    /// Operation corresponding to:
417    ///
418    /// Float => [mul scalar](burn_backend::ops::FloatTensorOps::float_mul_scalar).
419    /// Int => [mul scalar](burn_backend::ops::IntTensorOps::int_mul_scalar).
420    MulScalar(ScalarOpIr),
421    /// Operation corresponding to:
422    ///
423    /// Float => [abs](burn_backend::ops::FloatTensorOps::float_abs).
424    /// Int => [abs](burn_backend::ops::IntTensorOps::int_abs).
425    Abs(UnaryOpIr),
426    /// Operation corresponding to:
427    ///
428    /// Float => [full](burn_backend::ops::FloatTensorOps::float_full).
429    /// Int => [full](burn_backend::ops::IntTensorOps::int_full).
430    Full(FullOpIr),
431    /// Operation corresponding to:
432    ///
433    /// Float => [mean dim](burn_backend::ops::FloatTensorOps::float_mean_dim).
434    /// Int => [mean dim](burn_backend::ops::IntTensorOps::int_mean_dim).
435    MeanDim(ReduceDimOpIr),
436    /// Operation corresponding to:
437    ///
438    /// Float => [mean](burn_backend::ops::FloatTensorOps::float_mean).
439    /// Int => [mean](burn_backend::ops::IntTensorOps::int_mean).
440    Mean(ReduceOpIr),
441    /// Operation corresponding to:
442    ///
443    /// Float => [sum](burn_backend::ops::FloatTensorOps::float_sum).
444    /// Int => [sum](burn_backend::ops::IntTensorOps::int_sum).
445    Sum(ReduceOpIr),
446    /// Operation corresponding to:
447    ///
448    /// Float => [sum dim](burn_backend::ops::FloatTensorOps::float_sum_dim).
449    /// Int => [sum dim](burn_backend::ops::IntTensorOps::int_sum_dim).
450    SumDim(ReduceDimOpIr),
451    /// Operation corresponding to:
452    ///
453    /// Float => [prod](burn_backend::ops::FloatTensorOps::float_prod).
454    /// Int => [prod](burn_backend::ops::IntTensorOps::int_prod).
455    Prod(ReduceOpIr),
456    /// Operation corresponding to:
457    ///
458    /// Float => [prod dim](burn_backend::ops::FloatTensorOps::float_prod_dim).
459    /// Int => [prod dim](burn_backend::ops::IntTensorOps::int_prod_dim).
460    ProdDim(ReduceDimOpIr),
461    /// Operation corresponding to:
462    ///
463    /// Float => [greater](burn_backend::ops::FloatTensorOps::float_greater).
464    /// Int => [greater](burn_backend::ops::IntTensorOps::int_greater).
465    Greater(BinaryOpIr),
466    /// Operation corresponding to:
467    ///
468    /// Float => [greater elem](burn_backend::ops::FloatTensorOps::float_greater_elem).
469    /// Int => [greater elem](burn_backend::ops::IntTensorOps::int_greater_elem).
470    GreaterElem(ScalarOpIr),
471    /// Operation corresponding to:
472    ///
473    /// Float => [greater equal](burn_backend::ops::FloatTensorOps::float_greater_equal).
474    /// Int => [greater equal](burn_backend::ops::IntTensorOps::int_greater_equal).
475    GreaterEqual(BinaryOpIr),
476    /// Operation corresponding to:
477    ///
478    /// Float => [greater equal elem](burn_backend::ops::FloatTensorOps::float_greater_equal_elem).
479    /// Int => [greater equal elem](burn_backend::ops::IntTensorOps::int_greater_equal_elem).
480    GreaterEqualElem(ScalarOpIr),
481    /// Operation corresponding to:
482    ///
483    /// Float => [lower](burn_backend::ops::FloatTensorOps::float_lower).
484    /// Int => [lower](burn_backend::ops::IntTensorOps::int_lower).
485    Lower(BinaryOpIr),
486    /// Operation corresponding to:
487    ///
488    /// Float => [lower elem](burn_backend::ops::FloatTensorOps::float_lower_elem).
489    /// Int => [lower elem](burn_backend::ops::IntTensorOps::int_lower_elem).
490    LowerElem(ScalarOpIr),
491    /// Operation corresponding to:
492    ///
493    /// Float => [lower equal](burn_backend::ops::FloatTensorOps::float_lower_equal).
494    /// Int => [lower equal](burn_backend::ops::IntTensorOps::int_lower_equal).
495    LowerEqual(BinaryOpIr),
496    /// Operation corresponding to:
497    ///
498    /// Float => [lower equal elem](burn_backend::ops::FloatTensorOps::float_lower_equal_elem).
499    /// Int => [lower equal elem](burn_backend::ops::IntTensorOps::int_lower_equal_elem).
500    LowerEqualElem(ScalarOpIr),
501    /// Operation corresponding to:
502    ///
503    /// Float => [argmax](burn_backend::ops::FloatTensorOps::float_argmax).
504    /// Int => [argmax](burn_backend::ops::IntTensorOps::int_argmax).
505    ArgMax(ReduceDimOpIr),
506    /// Operation corresponding to:
507    ///
508    /// Float => [argmin](burn_backend::ops::FloatTensorOps::float_argmin).
509    /// Int => [argmin](burn_backend::ops::IntTensorOps::int_argmin).
510    ArgMin(ReduceDimOpIr),
511    /// Operation corresponding to:
512    ///
513    /// Float => [max](burn_backend::ops::FloatTensorOps::float_max).
514    /// Int => [max](burn_backend::ops::IntTensorOps::int_max).
515    Max(ReduceOpIr),
516    /// Operation corresponding to:
517    ///
518    /// Float => [max dim with indices](burn_backend::ops::FloatTensorOps::float_max_dim_with_indices).
519    /// Int => [max dim with indices](burn_backend::ops::IntTensorOps::int_max_dim_with_indices).
520    MaxDimWithIndices(ReduceDimWithIndicesOpIr),
521    /// Operation corresponding to:
522    ///
523    /// Float => [min dim with indices](burn_backend::ops::FloatTensorOps::float_min_dim_with_indices).
524    /// Int => [min dim with indices](burn_backend::ops::IntTensorOps::int_min_dim_with_indices).
525    MinDimWithIndices(ReduceDimWithIndicesOpIr),
526    /// Operation corresponding to:
527    ///
528    /// Float => [min](burn_backend::ops::FloatTensorOps::float_min).
529    /// Int => [min](burn_backend::ops::IntTensorOps::int_min).
530    Min(ReduceOpIr),
531    /// Operation corresponding to:
532    ///
533    /// Float => [max dim](burn_backend::ops::FloatTensorOps::float_max_dim).
534    /// Int => [max dim](burn_backend::ops::IntTensorOps::int_max_dim).
535    MaxDim(ReduceDimOpIr),
536    /// Operation corresponding to:
537    ///
538    /// Float => [min dim](burn_backend::ops::FloatTensorOps::float_min_dim).
539    /// Int => [min dim](burn_backend::ops::IntTensorOps::int_min_dim).
540    MinDim(ReduceDimOpIr),
541    /// Operation corresponding to:
542    ///
543    /// Float => [max_abs](burn_backend::ops::FloatTensorOps::float_max_abs).
544    /// Int => [max_abs](burn_backend::ops::IntTensorOps::int_max_abs).
545    MaxAbs(ReduceOpIr),
546    /// Operation corresponding to:
547    ///
548    /// Float => [max_abs dim](burn_backend::ops::FloatTensorOps::float_max_abs_dim).
549    /// Int => [max_abs dim](burn_backend::ops::IntTensorOps::int_max_abs_dim).
550    MaxAbsDim(ReduceDimOpIr),
551    /// Operation corresponding to:
552    ///
553    /// Float => [clamp](burn_backend::ops::FloatTensorOps::float_clamp).
554    /// Int => [clamp](burn_backend::ops::IntTensorOps::int_clamp).
555    Clamp(ClampOpIr),
556    /// Operation corresponding to:
557    ///
558    /// Int => [random](burn_backend::ops::IntTensorOps::int_random).
559    IntRandom(RandomOpIr),
560    /// Operation corresponding to:
561    ///
562    /// Float => [powf](burn_backend::ops::FloatTensorOps::float_powf).
563    /// Int => [powf](burn_backend::ops::IntTensorOps::int_powf).
564    Powf(BinaryOpIr),
565    /// Operation corresponding to:
566    ///
567    /// Float => [cumsum](burn_backend::ops::FloatTensorOps::float_cumsum).
568    /// Int => [cumsum](burn_backend::ops::IntTensorOps::int_cumsum).
569    CumSum(DimOpIr),
570    /// Operation corresponding to:
571    ///
572    /// Float => [cumprod](burn_backend::ops::FloatTensorOps::float_cumprod).
573    /// Int => [cumprod](burn_backend::ops::IntTensorOps::int_cumprod).
574    CumProd(DimOpIr),
575    /// Operation corresponding to:
576    ///
577    /// Float => [cummin](burn_backend::ops::FloatTensorOps::float_cummin).
578    /// Int => [cummin](burn_backend::ops::IntTensorOps::int_cummin).
579    CumMin(DimOpIr),
580    /// Operation corresponding to:
581    ///
582    /// Float => [cummax](burn_backend::ops::FloatTensorOps::float_cummax).
583    /// Int => [cummax](burn_backend::ops::IntTensorOps::int_cummax).
584    CumMax(DimOpIr),
585}
586
587/// Operation intermediate representation specific to an int tensor (float cast, bitwise ops, matmul).
588#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
589pub enum IntOperationIr {
590    /// Operation corresponding to [into float](burn_backend::ops::IntTensorOps::int_into_float).
591    IntoFloat(CastOpIr),
592    /// Operation corresponding to:
593    ///
594    /// Int => [bitwise and](burn_backend::ops::IntTensorOps::bitwise_and).
595    BitwiseAnd(BinaryOpIr),
596    /// Operation corresponding to:
597    ///
598    /// Int => [bitwise and scalar](burn_backend::ops::IntTensorOps::bitwise_and_scalar).
599    BitwiseAndScalar(ScalarOpIr),
600    /// Operation corresponding to:
601    ///
602    /// Int => [bitwise or](burn_backend::ops::IntTensorOps::bitwise_or).
603    BitwiseOr(BinaryOpIr),
604    /// Operation corresponding to:
605    ///
606    /// Int => [bitwise or scalar](burn_backend::ops::IntTensorOps::bitwise_or_scalar).
607    BitwiseOrScalar(ScalarOpIr),
608    /// Operation corresponding to:
609    ///
610    /// Int => [bitwise xor](burn_backend::ops::IntTensorOps::bitwise_xor).
611    BitwiseXor(BinaryOpIr),
612    /// Operation corresponding to:
613    ///
614    /// Int => [bitwise xor scalar](burn_backend::ops::IntTensorOps::bitwise_xor_scalar).
615    BitwiseXorScalar(ScalarOpIr),
616    /// Operation corresponding to:
617    ///
618    /// Int => [bitwise not](burn_backend::ops::IntTensorOps::bitwise_not).
619    BitwiseNot(UnaryOpIr),
620    /// Operation corresponding to:
621    ///
622    /// Int => [bitwise left shift](burn_backend::ops::IntTensorOps::bitwise_left_shift).
623    BitwiseLeftShift(BinaryOpIr),
624    /// Operation corresponding to:
625    ///
626    /// Int => [bitwise left shift scalar](burn_backend::ops::IntTensorOps::bitwise_left_shift_scalar).
627    BitwiseLeftShiftScalar(ScalarOpIr),
628    /// Operation corresponding to:
629    ///
630    /// Int => [bitwise right shift](burn_backend::ops::IntTensorOps::bitwise_right_shift).
631    BitwiseRightShift(BinaryOpIr),
632    /// Operation corresponding to:
633    ///
634    /// Int => [bitwise right shift scalar](burn_backend::ops::IntTensorOps::bitwise_right_shift_scalar).
635    BitwiseRightShiftScalar(ScalarOpIr),
636    /// Operation corresponding to [matmul](burn_backend::ops::IntTensorOps::int_matmul).
637    Matmul(MatmulOpIr),
638}
639
640/// Operation intermediate representation specific to a bool tensor (casts and logical not/and/or).
641#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
642pub enum BoolOperationIr {
643    /// Operation corresponding to [into float](burn_backend::ops::BoolTensorOps::bool_into_float).
644    IntoFloat(CastOpIr),
645    /// Operation corresponding to [into int](burn_backend::ops::BoolTensorOps::bool_into_int).
646    IntoInt(CastOpIr),
647    /// Operation corresponding to [not](burn_backend::ops::BoolTensorOps::bool_not).
648    Not(UnaryOpIr),
649    /// Operation corresponding to [and](burn_backend::ops::BoolTensorOps::bool_and).
650    And(BinaryOpIr),
651    /// Operation corresponding to [or](burn_backend::ops::BoolTensorOps::bool_or).
652    Or(BinaryOpIr),
653}
654
655/// Swap dims operation intermediate representation.
656#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
657pub struct SwapDimsOpIr {
658    /// Input tensor intermediate representation.
659    pub input: TensorIr,
660    /// Output tensor intermediate representation.
661    pub out: TensorIr,
662    /// The first dim to swap.
663    pub dim1: usize,
664    /// The second dim to swap.
665    pub dim2: usize,
666}
667
668/// Permute operation intermediate representation.
669#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
670pub struct PermuteOpIr {
671    /// Input tensor intermediate representation.
672    pub input: TensorIr,
673    /// Output tensor intermediate representation.
674    pub out: TensorIr,
675    /// The new order of the dimensions.
676    pub axes: Vec<usize>,
677}
678
679/// Shape operation intermediate representation.
680#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
681pub struct ShapeOpIr {
682    /// Input tensor intermediate representation.
683    pub input: TensorIr,
684    /// Output tensor intermediate representation with the new shape.
685    pub out: TensorIr,
686}
687
688/// Unfold operation intermediate representation.
689#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
690pub struct UnfoldOpIr {
691    /// Input tensor intermediate representation.
692    pub input: TensorIr,
693    /// Output tensor intermediate representation.
694    pub out: TensorIr,
695
696    /// The selected dim.
697    pub dim: usize,
698    /// The window size.
699    pub size: usize,
700    /// The window step along dim.
701    pub step: usize,
702}
703
704/// Flip operation intermediate representation.
705#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
706pub struct FlipOpIr {
707    /// Input tensor intermediate representation.
708    pub input: TensorIr,
709    /// Output tensor intermediate representation.
710    pub out: TensorIr,
711    /// The dimensions to flip.
712    pub axes: Vec<usize>,
713}
714
715#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
716#[allow(missing_docs)]
717pub struct RandomOpIr {
718    pub out: TensorIr,
719    pub distribution: Distribution,
720}
721
722/// Creation operation intermediate representation.
723/// As opposed to [InitOperationIr], creation operations are lazy initialized.
724#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
725pub struct CreationOpIr {
726    /// Output tensor intermediate representation.
727    pub out: TensorIr,
728}
729
730/// Full operation intermediate representation.
731#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
732pub struct FullOpIr {
733    /// Output tensor intermediate representation.
734    pub out: TensorIr,
735    /// Fill value.
736    pub value: ScalarIr,
737}
738
739#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
740/// Declares a tensor has been initialized.
741///
742/// It is necessary to register for proper orphan detection and avoid memory leak.
743pub struct InitOperationIr {
744    /// The initialized tensor.
745    pub out: TensorIr,
746}
747
748#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
749#[allow(missing_docs)]
750pub struct BinaryOpIr {
751    pub lhs: TensorIr,
752    pub rhs: TensorIr,
753    pub out: TensorIr,
754}
755
756#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
757#[allow(missing_docs)]
758pub struct MatmulOpIr {
759    pub lhs: TensorIr,
760    pub rhs: TensorIr,
761    pub out: TensorIr,
762}
763
764#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
765#[allow(missing_docs)]
766pub struct CrossOpIr {
767    pub lhs: TensorIr,
768    pub rhs: TensorIr,
769    pub out: TensorIr,
770    pub dim: usize,
771}
772
773#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
774#[allow(missing_docs)]
775pub struct UnaryOpIr {
776    pub input: TensorIr,
777    pub out: TensorIr,
778}
779
780#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
781#[allow(missing_docs)]
782pub struct ScalarOpIr {
783    pub lhs: TensorIr,
784    // TODO: Make that an enum with `Value` and `Id` variants for relative/global
785    // conversion.
786    pub rhs: ScalarIr,
787    pub out: TensorIr,
788}
789
790#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
791#[allow(missing_docs)]
792pub struct ReduceOpIr {
793    pub input: TensorIr,
794    pub out: TensorIr,
795}
796
797#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
798#[allow(missing_docs)]
799pub struct ReduceDimOpIr {
800    pub input: TensorIr,
801    pub out: TensorIr,
802    pub axis: usize,
803}
804
805#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
806#[allow(missing_docs)]
807pub struct CastOpIr {
808    pub input: TensorIr,
809    pub out: TensorIr,
810}
811
812/// IR for operations that operate along a dimension without reducing it.
813/// Unlike `ReduceDimOpIr`, the output shape is the same as the input shape.
814#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
815#[allow(missing_docs)]
816pub struct DimOpIr {
817    pub input: TensorIr,
818    pub out: TensorIr,
819    pub axis: usize,
820}
821
822#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
823#[allow(missing_docs)]
824pub struct GatherOpIr {
825    pub tensor: TensorIr,
826    pub dim: usize,
827    pub indices: TensorIr,
828    pub out: TensorIr,
829}
830
831#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
832#[allow(missing_docs)]
833pub struct ScatterOpIr {
834    pub tensor: TensorIr,
835    pub dim: usize,
836    pub indices: TensorIr,
837    pub value: TensorIr,
838    pub update: IndexingUpdateOp,
839    pub out: TensorIr,
840}
841
842#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
843#[allow(missing_docs)]
844pub struct SelectOpIr {
845    pub tensor: TensorIr,
846    pub dim: usize,
847    pub indices: TensorIr,
848    pub out: TensorIr,
849}
850
851#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
852#[allow(missing_docs)]
853pub struct SelectAssignOpIr {
854    pub tensor: TensorIr,
855    pub dim: usize,
856    pub indices: TensorIr,
857    pub value: TensorIr,
858    pub update: IndexingUpdateOp,
859    pub out: TensorIr,
860}
861
862#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
863#[allow(missing_docs)]
864pub struct SliceOpIr {
865    pub tensor: TensorIr,
866    pub ranges: Vec<Slice>,
867    pub out: TensorIr,
868}
869
870#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
871#[allow(missing_docs)]
872pub struct SliceAssignOpIr {
873    pub tensor: TensorIr,
874    pub ranges: Vec<burn_backend::Slice>,
875    pub value: TensorIr,
876    pub out: TensorIr,
877}
878
879#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
880#[allow(missing_docs)]
881pub struct MaskWhereOpIr {
882    pub tensor: TensorIr,
883    pub mask: TensorIr,
884    pub value: TensorIr,
885    pub out: TensorIr,
886}
887
888#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
889#[allow(missing_docs)]
890pub struct MaskFillOpIr {
891    pub tensor: TensorIr,
892    pub mask: TensorIr,
893    pub value: ScalarIr,
894    pub out: TensorIr,
895}
896
897#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
898#[allow(missing_docs)]
899pub struct ClampOpIr {
900    pub tensor: TensorIr,
901    pub min: ScalarIr,
902    pub max: ScalarIr,
903    pub out: TensorIr,
904}
905
906#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
907#[allow(missing_docs)]
908pub struct RepeatDimOpIr {
909    pub tensor: TensorIr,
910    pub dim: usize,
911    pub times: usize,
912    pub out: TensorIr,
913}
914
915#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
916#[allow(missing_docs)]
917pub struct CatOpIr {
918    pub tensors: Vec<TensorIr>,
919    pub dim: usize,
920    pub out: TensorIr,
921}
922
923#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
924#[allow(missing_docs)]
925pub struct ReduceDimWithIndicesOpIr {
926    pub tensor: TensorIr,
927    pub dim: usize,
928    pub out: TensorIr,
929    pub out_indices: TensorIr,
930}
931
932#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
933#[allow(missing_docs)]
934pub struct EmbeddingOpIr {
935    pub weights: TensorIr,
936    pub indices: TensorIr,
937    pub out: TensorIr,
938}
939
940#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
941#[allow(missing_docs)]
942pub struct EmbeddingBackwardOpIr {
943    pub weights: TensorIr,
944    pub out_grad: TensorIr,
945    pub indices: TensorIr,
946    pub out: TensorIr,
947}
948
949#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
950#[allow(missing_docs)]
951pub struct Conv1dOpIr {
952    pub x: TensorIr,
953    pub weight: TensorIr,
954    pub bias: Option<TensorIr>,
955    pub options: Conv1dOptionsIr,
956    pub out: TensorIr,
957}
958
959#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
960#[allow(missing_docs)]
961pub struct Conv2dOpIr {
962    pub x: TensorIr,
963    pub weight: TensorIr,
964    pub bias: Option<TensorIr>,
965    pub options: Conv2dOptionsIr,
966    pub out: TensorIr,
967}
968
969#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
970#[allow(missing_docs)]
971pub struct DeformConv2dOpIr {
972    pub x: TensorIr,
973    pub offset: TensorIr,
974    pub weight: TensorIr,
975    pub mask: Option<TensorIr>,
976    pub bias: Option<TensorIr>,
977    pub options: DeformableConv2dOptionsIr,
978    pub out: TensorIr,
979}
980
981#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
982#[allow(missing_docs)]
983pub struct DeformConv2dBackwardOpIr {
984    pub x: TensorIr,
985    pub offset: TensorIr,
986    pub weight: TensorIr,
987    pub mask: Option<TensorIr>,
988    pub bias: Option<TensorIr>,
989    pub out_grad: TensorIr,
990    pub options: DeformableConv2dOptionsIr,
991    pub input_grad: TensorIr,
992    pub offset_grad: TensorIr,
993    pub weight_grad: TensorIr,
994    pub mask_grad: Option<TensorIr>,
995    pub bias_grad: Option<TensorIr>,
996}
997
998#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
999#[allow(missing_docs)]
1000pub struct Conv3dOpIr {
1001    pub x: TensorIr,
1002    pub weight: TensorIr,
1003    pub bias: Option<TensorIr>,
1004    pub options: Conv3dOptionsIr,
1005    pub out: TensorIr,
1006}
1007
1008#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1009#[allow(missing_docs)]
1010pub struct ConvTranspose1dOpIr {
1011    pub x: TensorIr,
1012    pub weight: TensorIr,
1013    pub bias: Option<TensorIr>,
1014    pub options: ConvTranspose1dOptionsIr,
1015    pub out: TensorIr,
1016}
1017
1018#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1019#[allow(missing_docs)]
1020pub struct ConvTranspose2dOpIr {
1021    pub x: TensorIr,
1022    pub weight: TensorIr,
1023    pub bias: Option<TensorIr>,
1024    pub options: ConvTranspose2dOptionsIr,
1025    pub out: TensorIr,
1026}
1027
1028#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1029#[allow(missing_docs)]
1030pub struct ConvTranspose3dOpIr {
1031    pub x: TensorIr,
1032    pub weight: TensorIr,
1033    pub bias: Option<TensorIr>,
1034    pub options: ConvTranspose3dOptionsIr,
1035    pub out: TensorIr,
1036}
1037
1038#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1039#[allow(missing_docs)]
1040pub struct Conv1dOptionsIr {
1041    pub stride: [usize; 1],
1042    pub padding: [usize; 1],
1043    pub dilation: [usize; 1],
1044    pub groups: usize,
1045}
1046
1047#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1048#[allow(missing_docs)]
1049pub struct Conv2dOptionsIr {
1050    pub stride: [usize; 2],
1051    pub padding: [usize; 2],
1052    pub dilation: [usize; 2],
1053    pub groups: usize,
1054}
1055
1056#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1057#[allow(missing_docs)]
1058pub struct DeformableConv2dOptionsIr {
1059    pub stride: [usize; 2],
1060    pub padding: [usize; 2],
1061    pub dilation: [usize; 2],
1062    pub weight_groups: usize,
1063    pub offset_groups: usize,
1064}
1065
1066#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1067#[allow(missing_docs)]
1068pub struct Conv3dOptionsIr {
1069    pub stride: [usize; 3],
1070    pub padding: [usize; 3],
1071    pub dilation: [usize; 3],
1072    pub groups: usize,
1073}
1074
1075#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1076#[allow(missing_docs)]
1077pub struct ConvTranspose1dOptionsIr {
1078    pub stride: [usize; 1],
1079    pub padding: [usize; 1],
1080    pub padding_out: [usize; 1],
1081    pub dilation: [usize; 1],
1082    pub groups: usize,
1083}
1084
1085#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1086#[allow(missing_docs)]
1087pub struct ConvTranspose2dOptionsIr {
1088    pub stride: [usize; 2],
1089    pub padding: [usize; 2],
1090    pub padding_out: [usize; 2],
1091    pub dilation: [usize; 2],
1092    pub groups: usize,
1093}
1094
1095#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1096#[allow(missing_docs)]
1097pub struct ConvTranspose3dOptionsIr {
1098    pub stride: [usize; 3],
1099    pub padding: [usize; 3],
1100    pub padding_out: [usize; 3],
1101    pub dilation: [usize; 3],
1102    pub groups: usize,
1103}
1104
1105/// Quantization parameters intermediate representation.
1106#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
1107pub struct QuantizationParametersIr {
1108    /// The scaling factor.
1109    pub scales: TensorIr,
1110}
1111
1112#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1113#[allow(missing_docs)]
1114pub struct QuantizeOpIr {
1115    pub tensor: TensorIr,
1116    pub qparams: QuantizationParametersIr,
1117    pub scheme: QuantScheme,
1118    pub out: TensorIr,
1119}
1120
1121#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1122#[allow(missing_docs)]
1123pub struct DequantizeOpIr {
1124    pub input: TensorIr,
1125    pub out: TensorIr,
1126}
1127
1128impl From<ConvOptions<1>> for Conv1dOptionsIr {
1129    fn from(value: ConvOptions<1>) -> Self {
1130        Self {
1131            stride: value.stride,
1132            padding: value.padding,
1133            dilation: value.dilation,
1134            groups: value.groups,
1135        }
1136    }
1137}
1138
1139impl From<ConvOptions<2>> for Conv2dOptionsIr {
1140    fn from(value: ConvOptions<2>) -> Self {
1141        Self {
1142            stride: value.stride,
1143            padding: value.padding,
1144            dilation: value.dilation,
1145            groups: value.groups,
1146        }
1147    }
1148}
1149
1150impl From<ConvOptions<3>> for Conv3dOptionsIr {
1151    fn from(value: ConvOptions<3>) -> Self {
1152        Self {
1153            stride: value.stride,
1154            padding: value.padding,
1155            dilation: value.dilation,
1156            groups: value.groups,
1157        }
1158    }
1159}
1160
1161impl From<DeformConvOptions<2>> for DeformableConv2dOptionsIr {
1162    fn from(value: DeformConvOptions<2>) -> Self {
1163        Self {
1164            stride: value.stride,
1165            padding: value.padding,
1166            dilation: value.dilation,
1167            weight_groups: value.weight_groups,
1168            offset_groups: value.offset_groups,
1169        }
1170    }
1171}
1172
1173impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsIr {
1174    fn from(value: ConvTransposeOptions<1>) -> Self {
1175        Self {
1176            stride: value.stride,
1177            padding: value.padding,
1178            padding_out: value.padding_out,
1179            dilation: value.dilation,
1180            groups: value.groups,
1181        }
1182    }
1183}
1184
1185impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsIr {
1186    fn from(value: ConvTransposeOptions<2>) -> Self {
1187        Self {
1188            stride: value.stride,
1189            padding: value.padding,
1190            padding_out: value.padding_out,
1191            dilation: value.dilation,
1192            groups: value.groups,
1193        }
1194    }
1195}
1196
1197impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsIr {
1198    fn from(value: ConvTransposeOptions<3>) -> Self {
1199        Self {
1200            stride: value.stride,
1201            padding: value.padding,
1202            padding_out: value.padding_out,
1203            dilation: value.dilation,
1204            groups: value.groups,
1205        }
1206    }
1207}
1208
1209impl From<Conv1dOptionsIr> for ConvOptions<1> {
1210    fn from(val: Conv1dOptionsIr) -> Self {
1211        ConvOptions {
1212            stride: val.stride,
1213            padding: val.padding,
1214            dilation: val.dilation,
1215            groups: val.groups,
1216        }
1217    }
1218}
1219
1220impl From<Conv2dOptionsIr> for ConvOptions<2> {
1221    fn from(val: Conv2dOptionsIr) -> Self {
1222        ConvOptions {
1223            stride: val.stride,
1224            padding: val.padding,
1225            dilation: val.dilation,
1226            groups: val.groups,
1227        }
1228    }
1229}
1230
1231impl From<Conv3dOptionsIr> for ConvOptions<3> {
1232    fn from(val: Conv3dOptionsIr) -> Self {
1233        ConvOptions {
1234            stride: val.stride,
1235            padding: val.padding,
1236            dilation: val.dilation,
1237            groups: val.groups,
1238        }
1239    }
1240}
1241
1242impl From<DeformableConv2dOptionsIr> for DeformConvOptions<2> {
1243    fn from(value: DeformableConv2dOptionsIr) -> Self {
1244        DeformConvOptions {
1245            stride: value.stride,
1246            padding: value.padding,
1247            dilation: value.dilation,
1248            weight_groups: value.weight_groups,
1249            offset_groups: value.offset_groups,
1250        }
1251    }
1252}
1253
1254impl From<ConvTranspose1dOptionsIr> for ConvTransposeOptions<1> {
1255    fn from(val: ConvTranspose1dOptionsIr) -> Self {
1256        ConvTransposeOptions {
1257            stride: val.stride,
1258            padding: val.padding,
1259            padding_out: val.padding_out,
1260            dilation: val.dilation,
1261            groups: val.groups,
1262        }
1263    }
1264}
1265
1266impl From<ConvTranspose2dOptionsIr> for ConvTransposeOptions<2> {
1267    fn from(val: ConvTranspose2dOptionsIr) -> Self {
1268        ConvTransposeOptions {
1269            stride: val.stride,
1270            padding: val.padding,
1271            padding_out: val.padding_out,
1272            dilation: val.dilation,
1273            groups: val.groups,
1274        }
1275    }
1276}
1277
1278impl From<ConvTranspose3dOptionsIr> for ConvTransposeOptions<3> {
1279    fn from(val: ConvTranspose3dOptionsIr) -> Self {
1280        ConvTransposeOptions {
1281            stride: val.stride,
1282            padding: val.padding,
1283            padding_out: val.padding_out,
1284            dilation: val.dilation,
1285            groups: val.groups,
1286        }
1287    }
1288}
1289
1290#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1291#[allow(missing_docs)]
1292pub struct AvgPool1dOpIr {
1293    pub x: TensorIr,
1294    pub kernel_size: usize,
1295    pub stride: usize,
1296    pub padding: usize,
1297    pub count_include_pad: bool,
1298    pub ceil_mode: bool,
1299    pub out: TensorIr,
1300}
1301
1302#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1303#[allow(missing_docs)]
1304pub struct AvgPool2dOpIr {
1305    pub x: TensorIr,
1306    pub kernel_size: [usize; 2],
1307    pub stride: [usize; 2],
1308    pub padding: [usize; 2],
1309    pub count_include_pad: bool,
1310    pub ceil_mode: bool,
1311    pub out: TensorIr,
1312}
1313
1314#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1315#[allow(missing_docs)]
1316pub struct AvgPool1dBackwardOpIr {
1317    pub x: TensorIr,
1318    pub grad: TensorIr,
1319    pub kernel_size: usize,
1320    pub stride: usize,
1321    pub padding: usize,
1322    pub count_include_pad: bool,
1323    pub ceil_mode: bool,
1324    pub out: TensorIr,
1325}
1326
1327#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1328#[allow(missing_docs)]
1329pub struct AvgPool2dBackwardOpIr {
1330    pub x: TensorIr,
1331    pub grad: TensorIr,
1332    pub kernel_size: [usize; 2],
1333    pub stride: [usize; 2],
1334    pub padding: [usize; 2],
1335    pub count_include_pad: bool,
1336    pub ceil_mode: bool,
1337    pub out: TensorIr,
1338}
1339
1340#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1341#[allow(missing_docs)]
1342pub struct AdaptiveAvgPool1dOpIr {
1343    pub x: TensorIr,
1344    pub output_size: usize,
1345    pub out: TensorIr,
1346}
1347
1348#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1349#[allow(missing_docs)]
1350pub struct AdaptiveAvgPool2dOpIr {
1351    pub x: TensorIr,
1352    pub output_size: [usize; 2],
1353    pub out: TensorIr,
1354}
1355
1356#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1357#[allow(missing_docs)]
1358pub struct AdaptiveAvgPool1dBackwardOpIr {
1359    pub x: TensorIr,
1360    pub grad: TensorIr,
1361    pub out: TensorIr,
1362}
1363
1364#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1365#[allow(missing_docs)]
1366pub struct AdaptiveAvgPool2dBackwardOpIr {
1367    pub x: TensorIr,
1368    pub grad: TensorIr,
1369    pub out: TensorIr,
1370}
1371
1372#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1373#[allow(missing_docs)]
1374pub struct MaxPool1dOpIr {
1375    pub x: TensorIr,
1376    pub kernel_size: usize,
1377    pub stride: usize,
1378    pub padding: usize,
1379    pub dilation: usize,
1380    pub ceil_mode: bool,
1381    pub out: TensorIr,
1382}
1383
1384#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1385#[allow(missing_docs)]
1386pub struct MaxPool1dWithIndicesOpIr {
1387    pub x: TensorIr,
1388    pub kernel_size: usize,
1389    pub stride: usize,
1390    pub padding: usize,
1391    pub dilation: usize,
1392    pub ceil_mode: bool,
1393    pub out: TensorIr,
1394    pub out_indices: TensorIr,
1395}
1396
1397#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1398#[allow(missing_docs)]
1399pub struct MaxPool1dWithIndicesBackwardOpIr {
1400    pub x: TensorIr,
1401    pub grad: TensorIr,
1402    pub indices: TensorIr,
1403    pub kernel_size: usize,
1404    pub stride: usize,
1405    pub padding: usize,
1406    pub dilation: usize,
1407    pub ceil_mode: bool,
1408    pub out: TensorIr,
1409}
1410
1411#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1412#[allow(missing_docs)]
1413pub struct MaxPool2dOpIr {
1414    pub x: TensorIr,
1415    pub kernel_size: [usize; 2],
1416    pub stride: [usize; 2],
1417    pub padding: [usize; 2],
1418    pub dilation: [usize; 2],
1419    pub ceil_mode: bool,
1420    pub out: TensorIr,
1421}
1422
1423#[allow(missing_docs)]
1424#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1425pub struct MaxPool2dWithIndicesOpIr {
1426    pub x: TensorIr,
1427    pub kernel_size: [usize; 2],
1428    pub stride: [usize; 2],
1429    pub padding: [usize; 2],
1430    pub dilation: [usize; 2],
1431    pub ceil_mode: bool,
1432    pub out: TensorIr,
1433    pub out_indices: TensorIr,
1434}
1435
1436#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1437#[allow(missing_docs)]
1438pub struct MaxPool2dWithIndicesBackwardOpIr {
1439    pub x: TensorIr,
1440    pub grad: TensorIr,
1441    pub indices: TensorIr,
1442    pub kernel_size: [usize; 2],
1443    pub stride: [usize; 2],
1444    pub padding: [usize; 2],
1445    pub dilation: [usize; 2],
1446    pub ceil_mode: bool,
1447    pub out: TensorIr,
1448}
1449
1450#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1451#[allow(missing_docs)]
1452pub enum InterpolateModeIr {
1453    Nearest,
1454    Bilinear,
1455    Bicubic,
1456}
1457
1458#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1459#[allow(missing_docs)]
1460pub struct InterpolateOptionsIr {
1461    pub mode: InterpolateModeIr,
1462}
1463
1464#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1465#[allow(missing_docs)]
1466pub struct InterpolateOpIr {
1467    pub x: TensorIr,
1468    pub output_size: [usize; 2],
1469    pub options: InterpolateOptionsIr,
1470    pub out: TensorIr,
1471}
1472
1473impl From<InterpolateModeIr> for InterpolateMode {
1474    fn from(val: InterpolateModeIr) -> Self {
1475        match val {
1476            InterpolateModeIr::Nearest => Self::Nearest,
1477            InterpolateModeIr::Bilinear => Self::Bilinear,
1478            InterpolateModeIr::Bicubic => Self::Bicubic,
1479        }
1480    }
1481}
1482
1483impl From<InterpolateOptionsIr> for InterpolateOptions {
1484    fn from(val: InterpolateOptionsIr) -> Self {
1485        Self {
1486            mode: val.mode.into(),
1487        }
1488    }
1489}
1490
1491impl From<InterpolateMode> for InterpolateModeIr {
1492    fn from(val: InterpolateMode) -> Self {
1493        match val {
1494            InterpolateMode::Nearest => Self::Nearest,
1495            InterpolateMode::Bilinear => Self::Bilinear,
1496            InterpolateMode::Bicubic => Self::Bicubic,
1497        }
1498    }
1499}
1500
1501impl From<InterpolateOptions> for InterpolateOptionsIr {
1502    fn from(val: InterpolateOptions) -> Self {
1503        Self {
1504            mode: val.mode.into(),
1505        }
1506    }
1507}
1508
1509#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1510#[allow(missing_docs)]
1511pub struct InterpolateBackwardOpIr {
1512    pub x: TensorIr,
1513    pub grad: TensorIr,
1514    pub output_size: [usize; 2],
1515    pub options: InterpolateOptionsIr,
1516    pub out: TensorIr,
1517}
1518
1519#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1520#[allow(missing_docs)]
1521pub enum GridSamplePaddingModeIr {
1522    Zeros,
1523    Border,
1524    Reflection,
1525}
1526
1527#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1528#[allow(missing_docs)]
1529pub struct GridSampleOptionsIr {
1530    pub mode: InterpolateModeIr,
1531    pub padding_mode: GridSamplePaddingModeIr,
1532    pub align_corners: bool,
1533}
1534
1535#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
1536#[allow(missing_docs)]
1537pub struct GridSample2dOpIr {
1538    pub tensor: TensorIr,
1539    pub grid: TensorIr,
1540    pub options: GridSampleOptionsIr,
1541    pub out: TensorIr,
1542}
1543
1544impl From<GridSamplePaddingModeIr> for GridSamplePaddingMode {
1545    fn from(val: GridSamplePaddingModeIr) -> Self {
1546        match val {
1547            GridSamplePaddingModeIr::Zeros => Self::Zeros,
1548            GridSamplePaddingModeIr::Border => Self::Border,
1549            GridSamplePaddingModeIr::Reflection => Self::Reflection,
1550        }
1551    }
1552}
1553
1554impl From<GridSamplePaddingMode> for GridSamplePaddingModeIr {
1555    fn from(val: GridSamplePaddingMode) -> Self {
1556        match val {
1557            GridSamplePaddingMode::Zeros => Self::Zeros,
1558            GridSamplePaddingMode::Border => Self::Border,
1559            GridSamplePaddingMode::Reflection => Self::Reflection,
1560        }
1561    }
1562}
1563
1564impl From<GridSampleOptionsIr> for GridSampleOptions {
1565    fn from(val: GridSampleOptionsIr) -> Self {
1566        Self {
1567            mode: val.mode.into(),
1568            padding_mode: val.padding_mode.into(),
1569            align_corners: val.align_corners,
1570        }
1571    }
1572}
1573
1574impl From<GridSampleOptions> for GridSampleOptionsIr {
1575    fn from(val: GridSampleOptions) -> Self {
1576        Self {
1577            mode: val.mode.into(),
1578            padding_mode: val.padding_mode.into(),
1579            align_corners: val.align_corners,
1580        }
1581    }
1582}
1583
1584impl OperationIr {
1585    /// Get all input [tensors](TensorIr) involved with the current operation.
1586    pub fn inputs(&self) -> impl Iterator<Item = &TensorIr> {
1587        match self {
1588            OperationIr::BaseFloat(repr) => repr.inputs(),
1589            OperationIr::BaseInt(repr) => repr.inputs(),
1590            OperationIr::BaseBool(repr) => repr.inputs(),
1591            OperationIr::NumericFloat(_dtype, repr) => repr.inputs(),
1592            OperationIr::NumericInt(_dtype, repr) => repr.inputs(),
1593            OperationIr::Bool(repr) => repr.inputs(),
1594            OperationIr::Int(repr) => repr.inputs(),
1595            OperationIr::Float(_dtype, repr) => repr.inputs(),
1596            OperationIr::Module(repr) => repr.inputs(),
1597            OperationIr::Init(repr) => repr.inputs(),
1598            OperationIr::Custom(repr) => repr.inputs(),
1599            OperationIr::Drop(repr) => Box::new([repr].into_iter()),
1600        }
1601    }
1602
1603    /// Get all output [tensors](TensorIr) involved with the current operation.
1604    pub fn outputs(&self) -> impl Iterator<Item = &TensorIr> {
1605        match self {
1606            OperationIr::BaseFloat(repr) => repr.outputs(),
1607            OperationIr::BaseInt(repr) => repr.outputs(),
1608            OperationIr::BaseBool(repr) => repr.outputs(),
1609            OperationIr::NumericFloat(_dtype, repr) => repr.outputs(),
1610            OperationIr::NumericInt(_dtype, repr) => repr.outputs(),
1611            OperationIr::Bool(repr) => repr.outputs(),
1612            OperationIr::Int(repr) => repr.outputs(),
1613            OperationIr::Float(_dtype, repr) => repr.outputs(),
1614            OperationIr::Module(repr) => repr.outputs(),
1615            OperationIr::Init(repr) => repr.outputs(),
1616            OperationIr::Custom(repr) => repr.outputs(),
1617            OperationIr::Drop(_repr) => Box::new([].into_iter()),
1618        }
1619    }
1620
1621    /// Get all [tensor](TensorIr) involved with the current operation.
1622    pub fn nodes(&self) -> Vec<&TensorIr> {
1623        self.inputs().chain(self.outputs()).collect()
1624    }
1625
1626    /// Set the given nodes that are [read write](super::TensorStatus::ReadWrite) to
1627    /// [read only](super::TensorStatus::ReadOnly) in the current operation.
1628    ///
1629    /// Returns the tensor that were updated with their original representation.
1630    pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
1631        match self {
1632            OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes),
1633            OperationIr::BaseInt(repr) => repr.mark_read_only(nodes),
1634            OperationIr::BaseBool(repr) => repr.mark_read_only(nodes),
1635            OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes),
1636            OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes),
1637            OperationIr::Bool(repr) => repr.mark_read_only(nodes),
1638            OperationIr::Int(repr) => repr.mark_read_only(nodes),
1639            OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes),
1640            OperationIr::Module(repr) => repr.mark_read_only(nodes),
1641            OperationIr::Init(_) => Vec::new(),
1642            OperationIr::Drop(repr) => {
1643                let mut output = Vec::new();
1644                repr.mark_read_only(nodes, &mut output);
1645                output
1646            }
1647            OperationIr::Custom(repr) => {
1648                let mut output = Vec::new();
1649
1650                for input in repr.inputs.iter_mut() {
1651                    input.mark_read_only(nodes, &mut output);
1652                }
1653
1654                output
1655            }
1656        }
1657    }
1658}
1659
1660impl BaseOperationIr {
1661    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
1662        match self {
1663            BaseOperationIr::Reshape(repr) => Box::new([&repr.input].into_iter()),
1664            BaseOperationIr::SwapDims(repr) => Box::new([&repr.input].into_iter()),
1665            BaseOperationIr::Permute(repr) => Box::new([&repr.input].into_iter()),
1666            BaseOperationIr::Expand(repr) => Box::new([&repr.input].into_iter()),
1667            BaseOperationIr::Flip(repr) => Box::new([&repr.input].into_iter()),
1668            BaseOperationIr::Slice(repr) => Box::new([&repr.tensor].into_iter()),
1669            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.tensor, &repr.value].into_iter()),
1670            BaseOperationIr::Gather(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
1671            BaseOperationIr::Scatter(repr) => {
1672                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
1673            }
1674            BaseOperationIr::Select(repr) => Box::new([&repr.tensor, &repr.indices].into_iter()),
1675            BaseOperationIr::SelectAssign(repr) => {
1676                Box::new([&repr.tensor, &repr.indices, &repr.value].into_iter())
1677            }
1678            BaseOperationIr::MaskWhere(repr) => {
1679                Box::new([&repr.tensor, &repr.mask, &repr.value].into_iter())
1680            }
1681            BaseOperationIr::MaskFill(repr) => Box::new([&repr.tensor, &repr.mask].into_iter()),
1682            BaseOperationIr::Equal(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1683            BaseOperationIr::EqualElem(repr) => Box::new([&repr.lhs].into_iter()),
1684            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.tensor].into_iter()),
1685            BaseOperationIr::Cat(repr) => Box::new(repr.tensors.iter()),
1686            BaseOperationIr::Cast(repr) => Box::new([&repr.input].into_iter()),
1687            BaseOperationIr::Unfold(repr) => Box::new([&repr.input].into_iter()),
1688            BaseOperationIr::Empty(_repr) => Box::new([].into_iter()),
1689            BaseOperationIr::Ones(_repr) => Box::new([].into_iter()),
1690            BaseOperationIr::Zeros(_repr) => Box::new([].into_iter()),
1691        }
1692    }
1693
1694    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
1695        match self {
1696            BaseOperationIr::Reshape(repr) => Box::new([&repr.out].into_iter()),
1697            BaseOperationIr::SwapDims(repr) => Box::new([&repr.out].into_iter()),
1698            BaseOperationIr::Permute(repr) => Box::new([&repr.out].into_iter()),
1699            BaseOperationIr::Expand(repr) => Box::new([&repr.out].into_iter()),
1700            BaseOperationIr::Flip(repr) => Box::new([&repr.out].into_iter()),
1701            BaseOperationIr::Slice(repr) => Box::new([&repr.out].into_iter()),
1702            BaseOperationIr::SliceAssign(repr) => Box::new([&repr.out].into_iter()),
1703            BaseOperationIr::Gather(repr) => Box::new([&repr.out].into_iter()),
1704            BaseOperationIr::Scatter(repr) => Box::new([&repr.out].into_iter()),
1705            BaseOperationIr::Select(repr) => Box::new([&repr.out].into_iter()),
1706            BaseOperationIr::SelectAssign(repr) => Box::new([&repr.out].into_iter()),
1707            BaseOperationIr::MaskWhere(repr) => Box::new([&repr.out].into_iter()),
1708            BaseOperationIr::MaskFill(repr) => Box::new([&repr.out].into_iter()),
1709            BaseOperationIr::Equal(repr) => Box::new([&repr.out].into_iter()),
1710            BaseOperationIr::EqualElem(repr) => Box::new([&repr.out].into_iter()),
1711            BaseOperationIr::RepeatDim(repr) => Box::new([&repr.out].into_iter()),
1712            BaseOperationIr::Cat(repr) => Box::new([&repr.out].into_iter()),
1713            BaseOperationIr::Cast(repr) => Box::new([&repr.out].into_iter()),
1714            BaseOperationIr::Unfold(repr) => Box::new([&repr.out].into_iter()),
1715            BaseOperationIr::Empty(repr) => Box::new([&repr.out].into_iter()),
1716            BaseOperationIr::Ones(repr) => Box::new([&repr.out].into_iter()),
1717            BaseOperationIr::Zeros(repr) => Box::new([&repr.out].into_iter()),
1718        }
1719    }
1720
1721    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
1722        let mut output = Vec::new();
1723
1724        match self {
1725            BaseOperationIr::Reshape(repr) => {
1726                repr.input.mark_read_only(nodes, &mut output);
1727            }
1728            BaseOperationIr::SwapDims(repr) => {
1729                repr.input.mark_read_only(nodes, &mut output);
1730            }
1731            BaseOperationIr::Permute(repr) => {
1732                repr.input.mark_read_only(nodes, &mut output);
1733            }
1734
1735            BaseOperationIr::Expand(repr) => {
1736                repr.input.mark_read_only(nodes, &mut output);
1737            }
1738
1739            BaseOperationIr::Flip(repr) => {
1740                repr.input.mark_read_only(nodes, &mut output);
1741            }
1742            BaseOperationIr::Slice(repr) => {
1743                repr.tensor.mark_read_only(nodes, &mut output);
1744            }
1745            BaseOperationIr::SliceAssign(repr) => {
1746                repr.tensor.mark_read_only(nodes, &mut output);
1747                repr.value.mark_read_only(nodes, &mut output);
1748            }
1749            BaseOperationIr::Gather(repr) => {
1750                repr.tensor.mark_read_only(nodes, &mut output);
1751                repr.indices.mark_read_only(nodes, &mut output);
1752            }
1753            BaseOperationIr::Scatter(repr) => {
1754                repr.tensor.mark_read_only(nodes, &mut output);
1755                repr.indices.mark_read_only(nodes, &mut output);
1756                repr.value.mark_read_only(nodes, &mut output);
1757            }
1758            BaseOperationIr::Select(repr) => {
1759                repr.tensor.mark_read_only(nodes, &mut output);
1760                repr.indices.mark_read_only(nodes, &mut output);
1761            }
1762            BaseOperationIr::SelectAssign(repr) => {
1763                repr.tensor.mark_read_only(nodes, &mut output);
1764                repr.indices.mark_read_only(nodes, &mut output);
1765                repr.value.mark_read_only(nodes, &mut output);
1766            }
1767            BaseOperationIr::MaskWhere(repr) => {
1768                repr.tensor.mark_read_only(nodes, &mut output);
1769                repr.mask.mark_read_only(nodes, &mut output);
1770                repr.value.mark_read_only(nodes, &mut output);
1771            }
1772            BaseOperationIr::MaskFill(repr) => {
1773                repr.tensor.mark_read_only(nodes, &mut output);
1774                repr.mask.mark_read_only(nodes, &mut output);
1775            }
1776            BaseOperationIr::Equal(repr) => {
1777                repr.lhs.mark_read_only(nodes, &mut output);
1778                repr.rhs.mark_read_only(nodes, &mut output);
1779            }
1780            BaseOperationIr::EqualElem(repr) => {
1781                repr.lhs.mark_read_only(nodes, &mut output);
1782            }
1783            BaseOperationIr::RepeatDim(repr) => {
1784                repr.tensor.mark_read_only(nodes, &mut output);
1785            }
1786            BaseOperationIr::Cat(repr) => {
1787                for t in repr.tensors.iter_mut() {
1788                    t.mark_read_only(nodes, &mut output);
1789                }
1790            }
1791            BaseOperationIr::Cast(repr) => {
1792                repr.input.mark_read_only(nodes, &mut output);
1793            }
1794            BaseOperationIr::Unfold(repr) => {
1795                repr.input.mark_read_only(nodes, &mut output);
1796            }
1797            BaseOperationIr::Empty(_) => {}
1798            BaseOperationIr::Zeros(_) => {}
1799            BaseOperationIr::Ones(_) => {}
1800        };
1801
1802        output
1803    }
1804}
1805
1806impl NumericOperationIr {
1807    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
1808        match self {
1809            NumericOperationIr::Add(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1810            NumericOperationIr::AddScalar(repr) => Box::new([&repr.lhs].into_iter()),
1811            NumericOperationIr::Sub(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1812            NumericOperationIr::SubScalar(repr) => Box::new([&repr.lhs].into_iter()),
1813            NumericOperationIr::Mul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1814            NumericOperationIr::MulScalar(repr) => Box::new([&repr.lhs].into_iter()),
1815            NumericOperationIr::Div(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1816            NumericOperationIr::DivScalar(repr) => Box::new([&repr.lhs].into_iter()),
1817            NumericOperationIr::Rem(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1818            NumericOperationIr::RemScalar(repr) => Box::new([&repr.lhs].into_iter()),
1819            NumericOperationIr::GreaterElem(repr) => Box::new([&repr.lhs].into_iter()),
1820            NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
1821            NumericOperationIr::LowerElem(repr) => Box::new([&repr.lhs].into_iter()),
1822            NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.lhs].into_iter()),
1823            NumericOperationIr::Greater(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1824            NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1825            NumericOperationIr::Lower(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1826            NumericOperationIr::LowerEqual(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1827            NumericOperationIr::ArgMax(repr) => Box::new([&repr.input].into_iter()),
1828            NumericOperationIr::ArgMin(repr) => Box::new([&repr.input].into_iter()),
1829            NumericOperationIr::Clamp(repr) => Box::new([&repr.tensor].into_iter()),
1830            NumericOperationIr::Abs(repr) => Box::new([&repr.input].into_iter()),
1831            NumericOperationIr::Full(_repr) => Box::new([].into_iter()),
1832            NumericOperationIr::MeanDim(repr) => Box::new([&repr.input].into_iter()),
1833            NumericOperationIr::Mean(repr) => Box::new([&repr.input].into_iter()),
1834            NumericOperationIr::Sum(repr) => Box::new([&repr.input].into_iter()),
1835            NumericOperationIr::SumDim(repr) => Box::new([&repr.input].into_iter()),
1836            NumericOperationIr::Prod(repr) => Box::new([&repr.input].into_iter()),
1837            NumericOperationIr::ProdDim(repr) => Box::new([&repr.input].into_iter()),
1838            NumericOperationIr::Max(repr) => Box::new([&repr.input].into_iter()),
1839            NumericOperationIr::MaxDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
1840            NumericOperationIr::MinDimWithIndices(repr) => Box::new([&repr.tensor].into_iter()),
1841            NumericOperationIr::Min(repr) => Box::new([&repr.input].into_iter()),
1842            NumericOperationIr::MaxDim(repr) => Box::new([&repr.input].into_iter()),
1843            NumericOperationIr::MinDim(repr) => Box::new([&repr.input].into_iter()),
1844            NumericOperationIr::MaxAbs(repr) => Box::new([&repr.input].into_iter()),
1845            NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.input].into_iter()),
1846            NumericOperationIr::IntRandom(_repr) => Box::new([].into_iter()),
1847            NumericOperationIr::Powf(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
1848            NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),
1849            NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),
1850            NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),
1851            NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),
1852        }
1853    }
1854
1855    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
1856        match self {
1857            NumericOperationIr::Add(repr) => Box::new([&repr.out].into_iter()),
1858            NumericOperationIr::AddScalar(repr) => Box::new([&repr.out].into_iter()),
1859            NumericOperationIr::Sub(repr) => Box::new([&repr.out].into_iter()),
1860            NumericOperationIr::SubScalar(repr) => Box::new([&repr.out].into_iter()),
1861            NumericOperationIr::Mul(repr) => Box::new([&repr.out].into_iter()),
1862            NumericOperationIr::MulScalar(repr) => Box::new([&repr.out].into_iter()),
1863            NumericOperationIr::Div(repr) => Box::new([&repr.out].into_iter()),
1864            NumericOperationIr::DivScalar(repr) => Box::new([&repr.out].into_iter()),
1865            NumericOperationIr::Rem(repr) => Box::new([&repr.out].into_iter()),
1866            NumericOperationIr::RemScalar(repr) => Box::new([&repr.out].into_iter()),
1867            NumericOperationIr::GreaterElem(repr) => Box::new([&repr.out].into_iter()),
1868            NumericOperationIr::GreaterEqualElem(repr) => Box::new([&repr.out].into_iter()),
1869            NumericOperationIr::LowerElem(repr) => Box::new([&repr.out].into_iter()),
1870            NumericOperationIr::LowerEqualElem(repr) => Box::new([&repr.out].into_iter()),
1871            NumericOperationIr::Greater(repr) => Box::new([&repr.out].into_iter()),
1872            NumericOperationIr::GreaterEqual(repr) => Box::new([&repr.out].into_iter()),
1873            NumericOperationIr::Lower(repr) => Box::new([&repr.out].into_iter()),
1874            NumericOperationIr::LowerEqual(repr) => Box::new([&repr.out].into_iter()),
1875            NumericOperationIr::ArgMax(repr) => Box::new([&repr.out].into_iter()),
1876            NumericOperationIr::ArgMin(repr) => Box::new([&repr.out].into_iter()),
1877            NumericOperationIr::Clamp(repr) => Box::new([&repr.out].into_iter()),
1878            NumericOperationIr::Abs(repr) => Box::new([&repr.out].into_iter()),
1879            NumericOperationIr::Full(repr) => Box::new([&repr.out].into_iter()),
1880            NumericOperationIr::MeanDim(repr) => Box::new([&repr.out].into_iter()),
1881            NumericOperationIr::Mean(repr) => Box::new([&repr.out].into_iter()),
1882            NumericOperationIr::Sum(repr) => Box::new([&repr.out].into_iter()),
1883            NumericOperationIr::SumDim(repr) => Box::new([&repr.out].into_iter()),
1884            NumericOperationIr::Prod(repr) => Box::new([&repr.out].into_iter()),
1885            NumericOperationIr::ProdDim(repr) => Box::new([&repr.out].into_iter()),
1886            NumericOperationIr::Max(repr) => Box::new([&repr.out].into_iter()),
1887            NumericOperationIr::MaxDimWithIndices(repr) => {
1888                Box::new([&repr.out, &repr.out_indices].into_iter())
1889            }
1890            NumericOperationIr::MinDimWithIndices(repr) => {
1891                Box::new([&repr.out, &repr.out_indices].into_iter())
1892            }
1893            NumericOperationIr::Min(repr) => Box::new([&repr.out].into_iter()),
1894            NumericOperationIr::MaxDim(repr) => Box::new([&repr.out].into_iter()),
1895            NumericOperationIr::MinDim(repr) => Box::new([&repr.out].into_iter()),
1896            NumericOperationIr::MaxAbs(repr) => Box::new([&repr.out].into_iter()),
1897            NumericOperationIr::MaxAbsDim(repr) => Box::new([&repr.out].into_iter()),
1898            NumericOperationIr::IntRandom(repr) => Box::new([&repr.out].into_iter()),
1899            NumericOperationIr::Powf(repr) => Box::new([&repr.out].into_iter()),
1900            NumericOperationIr::CumMin(repr) => Box::new([&repr.out].into_iter()),
1901            NumericOperationIr::CumMax(repr) => Box::new([&repr.out].into_iter()),
1902            NumericOperationIr::CumProd(repr) => Box::new([&repr.out].into_iter()),
1903            NumericOperationIr::CumSum(repr) => Box::new([&repr.out].into_iter()),
1904        }
1905    }
1906    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
1907        let mut output = Vec::new();
1908
1909        match self {
1910            NumericOperationIr::Add(repr) => {
1911                repr.lhs.mark_read_only(nodes, &mut output);
1912                repr.rhs.mark_read_only(nodes, &mut output);
1913            }
1914            NumericOperationIr::AddScalar(repr) => {
1915                repr.lhs.mark_read_only(nodes, &mut output);
1916            }
1917            NumericOperationIr::Sub(repr) => {
1918                repr.lhs.mark_read_only(nodes, &mut output);
1919                repr.rhs.mark_read_only(nodes, &mut output);
1920            }
1921            NumericOperationIr::SubScalar(repr) => {
1922                repr.lhs.mark_read_only(nodes, &mut output);
1923            }
1924            NumericOperationIr::Mul(repr) => {
1925                repr.lhs.mark_read_only(nodes, &mut output);
1926                repr.rhs.mark_read_only(nodes, &mut output);
1927            }
1928            NumericOperationIr::MulScalar(repr) => {
1929                repr.lhs.mark_read_only(nodes, &mut output);
1930            }
1931            NumericOperationIr::Div(repr) => {
1932                repr.lhs.mark_read_only(nodes, &mut output);
1933                repr.rhs.mark_read_only(nodes, &mut output);
1934            }
1935            NumericOperationIr::DivScalar(repr) => {
1936                repr.lhs.mark_read_only(nodes, &mut output);
1937            }
1938            NumericOperationIr::Rem(repr) => {
1939                repr.lhs.mark_read_only(nodes, &mut output);
1940                repr.rhs.mark_read_only(nodes, &mut output);
1941            }
1942            NumericOperationIr::RemScalar(repr) => {
1943                repr.lhs.mark_read_only(nodes, &mut output);
1944            }
1945            NumericOperationIr::GreaterElem(repr) => {
1946                repr.lhs.mark_read_only(nodes, &mut output);
1947            }
1948            NumericOperationIr::GreaterEqualElem(repr) => {
1949                repr.lhs.mark_read_only(nodes, &mut output);
1950            }
1951            NumericOperationIr::LowerElem(repr) => {
1952                repr.lhs.mark_read_only(nodes, &mut output);
1953            }
1954            NumericOperationIr::LowerEqualElem(repr) => {
1955                repr.lhs.mark_read_only(nodes, &mut output);
1956            }
1957            NumericOperationIr::Greater(repr) => {
1958                repr.lhs.mark_read_only(nodes, &mut output);
1959                repr.rhs.mark_read_only(nodes, &mut output);
1960            }
1961            NumericOperationIr::GreaterEqual(repr) => {
1962                repr.lhs.mark_read_only(nodes, &mut output);
1963                repr.rhs.mark_read_only(nodes, &mut output);
1964            }
1965            NumericOperationIr::Lower(repr) => {
1966                repr.lhs.mark_read_only(nodes, &mut output);
1967                repr.rhs.mark_read_only(nodes, &mut output);
1968            }
1969            NumericOperationIr::LowerEqual(repr) => {
1970                repr.lhs.mark_read_only(nodes, &mut output);
1971                repr.rhs.mark_read_only(nodes, &mut output);
1972            }
1973            NumericOperationIr::ArgMax(repr) => {
1974                repr.input.mark_read_only(nodes, &mut output);
1975            }
1976            NumericOperationIr::ArgMin(repr) => {
1977                repr.input.mark_read_only(nodes, &mut output);
1978            }
1979            NumericOperationIr::Clamp(repr) => {
1980                repr.tensor.mark_read_only(nodes, &mut output);
1981            }
1982            NumericOperationIr::Abs(repr) => {
1983                repr.input.mark_read_only(nodes, &mut output);
1984            }
1985            NumericOperationIr::Full(_) => {}
1986            NumericOperationIr::MeanDim(repr) => {
1987                repr.input.mark_read_only(nodes, &mut output);
1988            }
1989            NumericOperationIr::Mean(repr) => {
1990                repr.input.mark_read_only(nodes, &mut output);
1991            }
1992            NumericOperationIr::Sum(repr) => {
1993                repr.input.mark_read_only(nodes, &mut output);
1994            }
1995            NumericOperationIr::SumDim(repr) => {
1996                repr.input.mark_read_only(nodes, &mut output);
1997            }
1998            NumericOperationIr::Prod(repr) => {
1999                repr.input.mark_read_only(nodes, &mut output);
2000            }
2001            NumericOperationIr::ProdDim(repr) => {
2002                repr.input.mark_read_only(nodes, &mut output);
2003            }
2004            NumericOperationIr::Max(repr) => {
2005                repr.input.mark_read_only(nodes, &mut output);
2006            }
2007            NumericOperationIr::MaxDimWithIndices(repr) => {
2008                repr.tensor.mark_read_only(nodes, &mut output);
2009            }
2010            NumericOperationIr::MinDimWithIndices(repr) => {
2011                repr.tensor.mark_read_only(nodes, &mut output);
2012            }
2013            NumericOperationIr::Min(repr) => {
2014                repr.input.mark_read_only(nodes, &mut output);
2015            }
2016            NumericOperationIr::MaxDim(repr) => {
2017                repr.input.mark_read_only(nodes, &mut output);
2018            }
2019            NumericOperationIr::MinDim(repr) => {
2020                repr.input.mark_read_only(nodes, &mut output);
2021            }
2022            NumericOperationIr::MaxAbs(repr) => {
2023                repr.input.mark_read_only(nodes, &mut output);
2024            }
2025            NumericOperationIr::MaxAbsDim(repr) => {
2026                repr.input.mark_read_only(nodes, &mut output);
2027            }
2028            NumericOperationIr::IntRandom(_) => {}
2029            NumericOperationIr::Powf(repr) => {
2030                repr.lhs.mark_read_only(nodes, &mut output);
2031                repr.rhs.mark_read_only(nodes, &mut output);
2032            }
2033            NumericOperationIr::CumSum(repr) => {
2034                repr.input.mark_read_only(nodes, &mut output);
2035            }
2036            NumericOperationIr::CumProd(repr) => {
2037                repr.input.mark_read_only(nodes, &mut output);
2038            }
2039            NumericOperationIr::CumMin(repr) => {
2040                repr.input.mark_read_only(nodes, &mut output);
2041            }
2042            NumericOperationIr::CumMax(repr) => {
2043                repr.input.mark_read_only(nodes, &mut output);
2044            }
2045        };
2046
2047        output
2048    }
2049}
2050
2051impl FloatOperationIr {
2052    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2053        match self {
2054            FloatOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2055            FloatOperationIr::Cross(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2056            FloatOperationIr::Random(_repr) => Box::new([].into_iter()),
2057            FloatOperationIr::Exp(repr) => Box::new([&repr.input].into_iter()),
2058            FloatOperationIr::Log(repr) => Box::new([&repr.input].into_iter()),
2059            FloatOperationIr::Log1p(repr) => Box::new([&repr.input].into_iter()),
2060            FloatOperationIr::Erf(repr) => Box::new([&repr.input].into_iter()),
2061            FloatOperationIr::Recip(repr) => Box::new([&repr.input].into_iter()),
2062            FloatOperationIr::PowfScalar(repr) => Box::new([&repr.lhs].into_iter()),
2063            FloatOperationIr::Sqrt(repr) => Box::new([&repr.input].into_iter()),
2064            FloatOperationIr::Cos(repr) => Box::new([&repr.input].into_iter()),
2065            FloatOperationIr::Sin(repr) => Box::new([&repr.input].into_iter()),
2066            FloatOperationIr::Tanh(repr) => Box::new([&repr.input].into_iter()),
2067            FloatOperationIr::Round(repr) => Box::new([&repr.input].into_iter()),
2068            FloatOperationIr::Floor(repr) => Box::new([&repr.input].into_iter()),
2069            FloatOperationIr::Ceil(repr) => Box::new([&repr.input].into_iter()),
2070            FloatOperationIr::Trunc(repr) => Box::new([&repr.input].into_iter()),
2071            FloatOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
2072            FloatOperationIr::Quantize(repr) => {
2073                Box::new([&repr.tensor, &repr.qparams.scales].into_iter())
2074            }
2075            FloatOperationIr::Dequantize(repr) => Box::new([&repr.input].into_iter()),
2076            FloatOperationIr::IsNan(repr) => Box::new([&repr.input].into_iter()),
2077            FloatOperationIr::IsInf(repr) => Box::new([&repr.input].into_iter()),
2078            FloatOperationIr::GridSample2d(repr) => {
2079                Box::new([&repr.tensor, &repr.grid].into_iter())
2080            }
2081            FloatOperationIr::Tan(repr) => Box::new([&repr.input].into_iter()),
2082            FloatOperationIr::Cosh(repr) => Box::new([&repr.input].into_iter()),
2083            FloatOperationIr::Sinh(repr) => Box::new([&repr.input].into_iter()),
2084            FloatOperationIr::ArcCos(repr) => Box::new([&repr.input].into_iter()),
2085            FloatOperationIr::ArcCosh(repr) => Box::new([&repr.input].into_iter()),
2086            FloatOperationIr::ArcSin(repr) => Box::new([&repr.input].into_iter()),
2087            FloatOperationIr::ArcSinh(repr) => Box::new([&repr.input].into_iter()),
2088            FloatOperationIr::ArcTan(repr) => Box::new([&repr.input].into_iter()),
2089            FloatOperationIr::ArcTanh(repr) => Box::new([&repr.input].into_iter()),
2090            FloatOperationIr::ArcTan2(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2091        }
2092    }
2093    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2094        match self {
2095            FloatOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
2096            FloatOperationIr::Cross(repr) => Box::new([&repr.out].into_iter()),
2097            FloatOperationIr::Random(repr) => Box::new([&repr.out].into_iter()),
2098            FloatOperationIr::Exp(repr) => Box::new([&repr.out].into_iter()),
2099            FloatOperationIr::Log(repr) => Box::new([&repr.out].into_iter()),
2100            FloatOperationIr::Log1p(repr) => Box::new([&repr.out].into_iter()),
2101            FloatOperationIr::Erf(repr) => Box::new([&repr.out].into_iter()),
2102            FloatOperationIr::Recip(repr) => Box::new([&repr.out].into_iter()),
2103            FloatOperationIr::PowfScalar(repr) => Box::new([&repr.out].into_iter()),
2104            FloatOperationIr::Sqrt(repr) => Box::new([&repr.out].into_iter()),
2105            FloatOperationIr::Cos(repr) => Box::new([&repr.out].into_iter()),
2106            FloatOperationIr::Sin(repr) => Box::new([&repr.out].into_iter()),
2107            FloatOperationIr::Tanh(repr) => Box::new([&repr.out].into_iter()),
2108            FloatOperationIr::Round(repr) => Box::new([&repr.out].into_iter()),
2109            FloatOperationIr::Floor(repr) => Box::new([&repr.out].into_iter()),
2110            FloatOperationIr::Ceil(repr) => Box::new([&repr.out].into_iter()),
2111            FloatOperationIr::Trunc(repr) => Box::new([&repr.out].into_iter()),
2112            FloatOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
2113            FloatOperationIr::Quantize(repr) => Box::new([&repr.out].into_iter()),
2114            FloatOperationIr::Dequantize(repr) => Box::new([&repr.out].into_iter()),
2115            FloatOperationIr::IsNan(repr) => Box::new([&repr.out].into_iter()),
2116            FloatOperationIr::IsInf(repr) => Box::new([&repr.out].into_iter()),
2117            FloatOperationIr::GridSample2d(repr) => Box::new([&repr.out].into_iter()),
2118            FloatOperationIr::Tan(repr) => Box::new([&repr.out].into_iter()),
2119            FloatOperationIr::Cosh(repr) => Box::new([&repr.out].into_iter()),
2120            FloatOperationIr::Sinh(repr) => Box::new([&repr.out].into_iter()),
2121            FloatOperationIr::ArcCos(repr) => Box::new([&repr.out].into_iter()),
2122            FloatOperationIr::ArcCosh(repr) => Box::new([&repr.out].into_iter()),
2123            FloatOperationIr::ArcSin(repr) => Box::new([&repr.out].into_iter()),
2124            FloatOperationIr::ArcSinh(repr) => Box::new([&repr.out].into_iter()),
2125            FloatOperationIr::ArcTan(repr) => Box::new([&repr.out].into_iter()),
2126            FloatOperationIr::ArcTanh(repr) => Box::new([&repr.out].into_iter()),
2127            FloatOperationIr::ArcTan2(repr) => Box::new([&repr.out].into_iter()),
2128        }
2129    }
2130
2131    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2132        let mut output = Vec::new();
2133
2134        match self {
2135            FloatOperationIr::Matmul(repr) => {
2136                repr.lhs.mark_read_only(nodes, &mut output);
2137                repr.rhs.mark_read_only(nodes, &mut output);
2138            }
2139            FloatOperationIr::Cross(repr) => {
2140                repr.lhs.mark_read_only(nodes, &mut output);
2141                repr.rhs.mark_read_only(nodes, &mut output);
2142            }
2143            FloatOperationIr::Random(_) => {}
2144            FloatOperationIr::Exp(repr) => {
2145                repr.input.mark_read_only(nodes, &mut output);
2146            }
2147            FloatOperationIr::Log(repr) => {
2148                repr.input.mark_read_only(nodes, &mut output);
2149            }
2150            FloatOperationIr::Log1p(repr) => {
2151                repr.input.mark_read_only(nodes, &mut output);
2152            }
2153            FloatOperationIr::Erf(repr) => {
2154                repr.input.mark_read_only(nodes, &mut output);
2155            }
2156            FloatOperationIr::Recip(repr) => {
2157                repr.input.mark_read_only(nodes, &mut output);
2158            }
2159            FloatOperationIr::PowfScalar(repr) => {
2160                repr.lhs.mark_read_only(nodes, &mut output);
2161            }
2162            FloatOperationIr::Sqrt(repr) => {
2163                repr.input.mark_read_only(nodes, &mut output);
2164            }
2165            FloatOperationIr::Cos(repr) => {
2166                repr.input.mark_read_only(nodes, &mut output);
2167            }
2168            FloatOperationIr::Sin(repr) => {
2169                repr.input.mark_read_only(nodes, &mut output);
2170            }
2171            FloatOperationIr::Tanh(repr) => {
2172                repr.input.mark_read_only(nodes, &mut output);
2173            }
2174            FloatOperationIr::Round(repr) => {
2175                repr.input.mark_read_only(nodes, &mut output);
2176            }
2177            FloatOperationIr::Floor(repr) => {
2178                repr.input.mark_read_only(nodes, &mut output);
2179            }
2180            FloatOperationIr::Ceil(repr) => {
2181                repr.input.mark_read_only(nodes, &mut output);
2182            }
2183            FloatOperationIr::Trunc(repr) => {
2184                repr.input.mark_read_only(nodes, &mut output);
2185            }
2186            FloatOperationIr::Quantize(repr) => {
2187                repr.tensor.mark_read_only(nodes, &mut output);
2188                repr.qparams.scales.mark_read_only(nodes, &mut output);
2189            }
2190            FloatOperationIr::Dequantize(repr) => {
2191                repr.input.mark_read_only(nodes, &mut output);
2192            }
2193            FloatOperationIr::IntoInt(repr) => {
2194                repr.input.mark_read_only(nodes, &mut output);
2195            }
2196            FloatOperationIr::IsNan(repr) => {
2197                repr.input.mark_read_only(nodes, &mut output);
2198            }
2199            FloatOperationIr::IsInf(repr) => {
2200                repr.input.mark_read_only(nodes, &mut output);
2201            }
2202            FloatOperationIr::GridSample2d(repr) => {
2203                repr.tensor.mark_read_only(nodes, &mut output);
2204                repr.grid.mark_read_only(nodes, &mut output);
2205            }
2206            FloatOperationIr::Tan(repr) => repr.input.mark_read_only(nodes, &mut output),
2207            FloatOperationIr::Cosh(repr) => repr.input.mark_read_only(nodes, &mut output),
2208            FloatOperationIr::Sinh(repr) => repr.input.mark_read_only(nodes, &mut output),
2209            FloatOperationIr::ArcCos(repr) => repr.input.mark_read_only(nodes, &mut output),
2210            FloatOperationIr::ArcCosh(repr) => repr.input.mark_read_only(nodes, &mut output),
2211            FloatOperationIr::ArcSin(repr) => repr.input.mark_read_only(nodes, &mut output),
2212            FloatOperationIr::ArcSinh(repr) => repr.input.mark_read_only(nodes, &mut output),
2213            FloatOperationIr::ArcTan(repr) => repr.input.mark_read_only(nodes, &mut output),
2214            FloatOperationIr::ArcTanh(repr) => repr.input.mark_read_only(nodes, &mut output),
2215            FloatOperationIr::ArcTan2(repr) => {
2216                repr.lhs.mark_read_only(nodes, &mut output);
2217                repr.rhs.mark_read_only(nodes, &mut output);
2218            }
2219        };
2220
2221        output
2222    }
2223}
2224
2225impl IntOperationIr {
2226    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2227        match self {
2228            IntOperationIr::Matmul(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2229            IntOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
2230            IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2231            IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.lhs].into_iter()),
2232            IntOperationIr::BitwiseOr(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2233            IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.lhs].into_iter()),
2234            IntOperationIr::BitwiseXor(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2235            IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.lhs].into_iter()),
2236            IntOperationIr::BitwiseNot(repr) => Box::new([&repr.input].into_iter()),
2237            IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2238            IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
2239            IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2240            IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.lhs].into_iter()),
2241        }
2242    }
2243
2244    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2245        match self {
2246            IntOperationIr::Matmul(repr) => Box::new([&repr.out].into_iter()),
2247            IntOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
2248            IntOperationIr::BitwiseAnd(repr) => Box::new([&repr.out].into_iter()),
2249            IntOperationIr::BitwiseAndScalar(repr) => Box::new([&repr.out].into_iter()),
2250            IntOperationIr::BitwiseOr(repr) => Box::new([&repr.out].into_iter()),
2251            IntOperationIr::BitwiseOrScalar(repr) => Box::new([&repr.out].into_iter()),
2252            IntOperationIr::BitwiseXor(repr) => Box::new([&repr.out].into_iter()),
2253            IntOperationIr::BitwiseXorScalar(repr) => Box::new([&repr.out].into_iter()),
2254            IntOperationIr::BitwiseNot(repr) => Box::new([&repr.out].into_iter()),
2255            IntOperationIr::BitwiseLeftShift(repr) => Box::new([&repr.out].into_iter()),
2256            IntOperationIr::BitwiseLeftShiftScalar(repr) => Box::new([&repr.out].into_iter()),
2257            IntOperationIr::BitwiseRightShift(repr) => Box::new([&repr.out].into_iter()),
2258            IntOperationIr::BitwiseRightShiftScalar(repr) => Box::new([&repr.out].into_iter()),
2259        }
2260    }
2261
2262    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2263        let mut output = Vec::new();
2264
2265        match self {
2266            IntOperationIr::Matmul(repr) => {
2267                repr.lhs.mark_read_only(nodes, &mut output);
2268                repr.rhs.mark_read_only(nodes, &mut output);
2269            }
2270            IntOperationIr::IntoFloat(repr) => {
2271                repr.input.mark_read_only(nodes, &mut output);
2272            }
2273            IntOperationIr::BitwiseAnd(repr) => {
2274                repr.lhs.mark_read_only(nodes, &mut output);
2275                repr.rhs.mark_read_only(nodes, &mut output);
2276            }
2277            IntOperationIr::BitwiseAndScalar(repr) => {
2278                repr.lhs.mark_read_only(nodes, &mut output);
2279            }
2280            IntOperationIr::BitwiseOr(repr) => {
2281                repr.lhs.mark_read_only(nodes, &mut output);
2282                repr.rhs.mark_read_only(nodes, &mut output);
2283            }
2284            IntOperationIr::BitwiseOrScalar(repr) => {
2285                repr.lhs.mark_read_only(nodes, &mut output);
2286            }
2287            IntOperationIr::BitwiseXor(repr) => {
2288                repr.lhs.mark_read_only(nodes, &mut output);
2289                repr.rhs.mark_read_only(nodes, &mut output);
2290            }
2291            IntOperationIr::BitwiseXorScalar(repr) => {
2292                repr.lhs.mark_read_only(nodes, &mut output);
2293            }
2294            IntOperationIr::BitwiseNot(repr) => {
2295                repr.input.mark_read_only(nodes, &mut output);
2296            }
2297            IntOperationIr::BitwiseLeftShift(repr) => {
2298                repr.lhs.mark_read_only(nodes, &mut output);
2299                repr.rhs.mark_read_only(nodes, &mut output);
2300            }
2301            IntOperationIr::BitwiseLeftShiftScalar(repr) => {
2302                repr.lhs.mark_read_only(nodes, &mut output);
2303            }
2304            IntOperationIr::BitwiseRightShift(repr) => {
2305                repr.lhs.mark_read_only(nodes, &mut output);
2306                repr.rhs.mark_read_only(nodes, &mut output);
2307            }
2308            IntOperationIr::BitwiseRightShiftScalar(repr) => {
2309                repr.lhs.mark_read_only(nodes, &mut output);
2310            }
2311        };
2312
2313        output
2314    }
2315}
2316
2317impl BoolOperationIr {
2318    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2319        match self {
2320            BoolOperationIr::IntoFloat(repr) => Box::new([&repr.input].into_iter()),
2321            BoolOperationIr::IntoInt(repr) => Box::new([&repr.input].into_iter()),
2322            BoolOperationIr::Not(repr) => Box::new([&repr.input].into_iter()),
2323            BoolOperationIr::And(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2324            BoolOperationIr::Or(repr) => Box::new([&repr.lhs, &repr.rhs].into_iter()),
2325        }
2326    }
2327    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2328        match self {
2329            BoolOperationIr::IntoFloat(repr) => Box::new([&repr.out].into_iter()),
2330            BoolOperationIr::IntoInt(repr) => Box::new([&repr.out].into_iter()),
2331            BoolOperationIr::Not(repr) => Box::new([&repr.out].into_iter()),
2332            BoolOperationIr::And(repr) => Box::new([&repr.out].into_iter()),
2333            BoolOperationIr::Or(repr) => Box::new([&repr.out].into_iter()),
2334        }
2335    }
2336    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2337        let mut output = Vec::new();
2338
2339        match self {
2340            BoolOperationIr::IntoFloat(repr) => {
2341                repr.input.mark_read_only(nodes, &mut output);
2342            }
2343            BoolOperationIr::IntoInt(repr) => {
2344                repr.input.mark_read_only(nodes, &mut output);
2345            }
2346            BoolOperationIr::Not(repr) => {
2347                repr.input.mark_read_only(nodes, &mut output);
2348            }
2349            BoolOperationIr::And(repr) => {
2350                repr.lhs.mark_read_only(nodes, &mut output);
2351                repr.rhs.mark_read_only(nodes, &mut output);
2352            }
2353            BoolOperationIr::Or(repr) => {
2354                repr.lhs.mark_read_only(nodes, &mut output);
2355                repr.rhs.mark_read_only(nodes, &mut output);
2356            }
2357        };
2358
2359        output
2360    }
2361}
2362
2363impl ModuleOperationIr {
2364    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2365        match self {
2366            ModuleOperationIr::Embedding(repr) => {
2367                Box::new([&repr.weights, &repr.indices].into_iter())
2368            }
2369            ModuleOperationIr::EmbeddingBackward(repr) => {
2370                Box::new([&repr.weights, &repr.out_grad, &repr.indices].into_iter())
2371            }
2372            ModuleOperationIr::Conv1d(repr) => {
2373                if let Some(bias) = &repr.bias {
2374                    Box::new([&repr.x, &repr.weight, bias].into_iter())
2375                } else {
2376                    Box::new([&repr.x, &repr.weight].into_iter())
2377                }
2378            }
2379            ModuleOperationIr::Conv2d(repr) => {
2380                if let Some(bias) = &repr.bias {
2381                    Box::new([&repr.x, &repr.weight, bias].into_iter())
2382                } else {
2383                    Box::new([&repr.x, &repr.weight].into_iter())
2384                }
2385            }
2386            ModuleOperationIr::Conv3d(repr) => {
2387                if let Some(bias) = &repr.bias {
2388                    Box::new([&repr.x, &repr.weight, bias].into_iter())
2389                } else {
2390                    Box::new([&repr.x, &repr.weight].into_iter())
2391                }
2392            }
2393            ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) {
2394                (Some(mask), Some(bias)) => {
2395                    Box::new([&repr.x, &repr.offset, &repr.weight, mask, bias].into_iter())
2396                }
2397                (Some(mask), None) => {
2398                    Box::new([&repr.x, &repr.offset, &repr.weight, mask].into_iter())
2399                }
2400                (None, Some(bias)) => {
2401                    Box::new([&repr.x, &repr.offset, &repr.weight, bias].into_iter())
2402                }
2403                (None, None) => Box::new([&repr.x, &repr.offset, &repr.weight].into_iter()),
2404            },
2405            ModuleOperationIr::DeformableConv2dBackward(repr) => match (&repr.mask, &repr.bias) {
2406                (Some(mask), Some(bias)) => Box::new(
2407                    [
2408                        &repr.x,
2409                        &repr.offset,
2410                        &repr.weight,
2411                        &repr.out_grad,
2412                        mask,
2413                        bias,
2414                    ]
2415                    .into_iter(),
2416                ),
2417                (Some(mask), None) => Box::new(
2418                    [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, mask].into_iter(),
2419                ),
2420                (None, Some(bias)) => Box::new(
2421                    [&repr.x, &repr.offset, &repr.weight, &repr.out_grad, bias].into_iter(),
2422                ),
2423                (None, None) => {
2424                    Box::new([&repr.x, &repr.offset, &repr.weight, &repr.out_grad].into_iter())
2425                }
2426            },
2427            ModuleOperationIr::ConvTranspose1d(repr) => {
2428                if let Some(bias) = &repr.bias {
2429                    Box::new([&repr.x, &repr.weight, bias].into_iter())
2430                } else {
2431                    Box::new([&repr.x, &repr.weight].into_iter())
2432                }
2433            }
2434            ModuleOperationIr::ConvTranspose2d(repr) => {
2435                if let Some(bias) = &repr.bias {
2436                    Box::new([&repr.x, &repr.weight, bias].into_iter())
2437                } else {
2438                    Box::new([&repr.x, &repr.weight].into_iter())
2439                }
2440            }
2441            ModuleOperationIr::ConvTranspose3d(repr) => {
2442                if let Some(bias) = &repr.bias {
2443                    Box::new([&repr.x, &repr.weight, bias].into_iter())
2444                } else {
2445                    Box::new([&repr.x, &repr.weight].into_iter())
2446                }
2447            }
2448            ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.x].into_iter()),
2449            ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.x].into_iter()),
2450            ModuleOperationIr::AvgPool1dBackward(repr) => {
2451                Box::new([&repr.x, &repr.grad].into_iter())
2452            }
2453            ModuleOperationIr::AvgPool2dBackward(repr) => {
2454                Box::new([&repr.x, &repr.grad].into_iter())
2455            }
2456            ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.x].into_iter()),
2457            ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.x].into_iter()),
2458            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2459                Box::new([&repr.x, &repr.grad].into_iter())
2460            }
2461            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2462                Box::new([&repr.x, &repr.grad].into_iter())
2463            }
2464            ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.x].into_iter()),
2465            ModuleOperationIr::MaxPool1dWithIndices(repr) => Box::new([&repr.x].into_iter()),
2466            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2467                Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
2468            }
2469            ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.x].into_iter()),
2470            ModuleOperationIr::MaxPool2dWithIndices(repr) => Box::new([&repr.x].into_iter()),
2471            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2472                Box::new([&repr.x, &repr.indices, &repr.grad].into_iter())
2473            }
2474            ModuleOperationIr::Interpolate(repr) => Box::new([&repr.x].into_iter()),
2475            ModuleOperationIr::InterpolateBackward(repr) => {
2476                Box::new([&repr.x, &repr.grad].into_iter())
2477            }
2478        }
2479    }
2480    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2481        match self {
2482            ModuleOperationIr::Embedding(repr) => Box::new([&repr.out].into_iter()),
2483            ModuleOperationIr::EmbeddingBackward(repr) => Box::new([&repr.out].into_iter()),
2484            ModuleOperationIr::Conv1d(repr) => Box::new([&repr.out].into_iter()),
2485            ModuleOperationIr::Conv2d(repr) => Box::new([&repr.out].into_iter()),
2486            ModuleOperationIr::Conv3d(repr) => Box::new([&repr.out].into_iter()),
2487            ModuleOperationIr::DeformableConv2d(repr) => Box::new([&repr.out].into_iter()),
2488            ModuleOperationIr::DeformableConv2dBackward(repr) => {
2489                match (&repr.mask_grad, &repr.bias_grad) {
2490                    (Some(mask_grad), Some(bias_grad)) => Box::new(
2491                        [
2492                            &repr.input_grad,
2493                            &repr.offset_grad,
2494                            &repr.weight_grad,
2495                            mask_grad,
2496                            bias_grad,
2497                        ]
2498                        .into_iter(),
2499                    ),
2500                    (Some(mask_grad), None) => Box::new(
2501                        [
2502                            &repr.input_grad,
2503                            &repr.offset_grad,
2504                            &repr.weight_grad,
2505                            mask_grad,
2506                        ]
2507                        .into_iter(),
2508                    ),
2509                    (None, Some(bias_grad)) => Box::new(
2510                        [
2511                            &repr.input_grad,
2512                            &repr.offset_grad,
2513                            &repr.weight_grad,
2514                            bias_grad,
2515                        ]
2516                        .into_iter(),
2517                    ),
2518                    (None, None) => Box::new(
2519                        [&repr.input_grad, &repr.offset_grad, &repr.weight_grad].into_iter(),
2520                    ),
2521                }
2522            }
2523            ModuleOperationIr::ConvTranspose1d(repr) => Box::new([&repr.out].into_iter()),
2524            ModuleOperationIr::ConvTranspose2d(repr) => Box::new([&repr.out].into_iter()),
2525            ModuleOperationIr::ConvTranspose3d(repr) => Box::new([&repr.out].into_iter()),
2526            ModuleOperationIr::AvgPool1d(repr) => Box::new([&repr.out].into_iter()),
2527            ModuleOperationIr::AvgPool2d(repr) => Box::new([&repr.out].into_iter()),
2528            ModuleOperationIr::AvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
2529            ModuleOperationIr::AvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
2530            ModuleOperationIr::AdaptiveAvgPool1d(repr) => Box::new([&repr.out].into_iter()),
2531            ModuleOperationIr::AdaptiveAvgPool2d(repr) => Box::new([&repr.out].into_iter()),
2532            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => Box::new([&repr.out].into_iter()),
2533            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => Box::new([&repr.out].into_iter()),
2534            ModuleOperationIr::MaxPool1d(repr) => Box::new([&repr.out].into_iter()),
2535            ModuleOperationIr::MaxPool1dWithIndices(repr) => {
2536                Box::new([&repr.out, &repr.out_indices].into_iter())
2537            }
2538            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2539                Box::new([&repr.out].into_iter())
2540            }
2541            ModuleOperationIr::MaxPool2d(repr) => Box::new([&repr.out].into_iter()),
2542            ModuleOperationIr::MaxPool2dWithIndices(repr) => {
2543                Box::new([&repr.out, &repr.out_indices].into_iter())
2544            }
2545            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2546                Box::new([&repr.out].into_iter())
2547            }
2548            ModuleOperationIr::Interpolate(repr) => Box::new([&repr.out].into_iter()),
2549            ModuleOperationIr::InterpolateBackward(repr) => Box::new([&repr.out].into_iter()),
2550        }
2551    }
2552
2553    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2554        let mut output = Vec::new();
2555
2556        match self {
2557            ModuleOperationIr::Embedding(repr) => {
2558                repr.weights.mark_read_only(nodes, &mut output);
2559                repr.indices.mark_read_only(nodes, &mut output);
2560            }
2561            ModuleOperationIr::EmbeddingBackward(repr) => {
2562                repr.weights.mark_read_only(nodes, &mut output);
2563                repr.out_grad.mark_read_only(nodes, &mut output);
2564                repr.indices.mark_read_only(nodes, &mut output);
2565            }
2566            ModuleOperationIr::Conv1d(repr) => {
2567                repr.x.mark_read_only(nodes, &mut output);
2568                repr.weight.mark_read_only(nodes, &mut output);
2569
2570                if let Some(bias) = &mut repr.bias {
2571                    bias.mark_read_only(nodes, &mut output);
2572                }
2573            }
2574            ModuleOperationIr::Conv2d(repr) => {
2575                repr.x.mark_read_only(nodes, &mut output);
2576                repr.weight.mark_read_only(nodes, &mut output);
2577
2578                if let Some(bias) = &mut repr.bias {
2579                    bias.mark_read_only(nodes, &mut output);
2580                }
2581            }
2582            ModuleOperationIr::Conv3d(repr) => {
2583                repr.x.mark_read_only(nodes, &mut output);
2584                repr.weight.mark_read_only(nodes, &mut output);
2585
2586                if let Some(bias) = &mut repr.bias {
2587                    bias.mark_read_only(nodes, &mut output);
2588                }
2589            }
2590            ModuleOperationIr::DeformableConv2d(repr) => {
2591                repr.x.mark_read_only(nodes, &mut output);
2592                repr.weight.mark_read_only(nodes, &mut output);
2593                repr.offset.mark_read_only(nodes, &mut output);
2594
2595                match (&mut repr.mask, &mut repr.bias) {
2596                    (Some(mask), Some(bias)) => {
2597                        mask.mark_read_only(nodes, &mut output);
2598                        bias.mark_read_only(nodes, &mut output);
2599                    }
2600                    (Some(mask), None) => {
2601                        mask.mark_read_only(nodes, &mut output);
2602                    }
2603                    (None, Some(bias)) => {
2604                        bias.mark_read_only(nodes, &mut output);
2605                    }
2606                    (None, None) => {}
2607                };
2608            }
2609            ModuleOperationIr::DeformableConv2dBackward(repr) => {
2610                repr.x.mark_read_only(nodes, &mut output);
2611                repr.weight.mark_read_only(nodes, &mut output);
2612                repr.offset.mark_read_only(nodes, &mut output);
2613                repr.out_grad.mark_read_only(nodes, &mut output);
2614
2615                if let Some(mask) = repr.mask.as_mut() {
2616                    mask.mark_read_only(nodes, &mut output);
2617                }
2618                if let Some(bias) = repr.bias.as_mut() {
2619                    bias.mark_read_only(nodes, &mut output);
2620                }
2621            }
2622            ModuleOperationIr::ConvTranspose1d(repr) => {
2623                repr.x.mark_read_only(nodes, &mut output);
2624                repr.weight.mark_read_only(nodes, &mut output);
2625
2626                if let Some(bias) = &mut repr.bias {
2627                    bias.mark_read_only(nodes, &mut output);
2628                }
2629            }
2630            ModuleOperationIr::ConvTranspose2d(repr) => {
2631                repr.x.mark_read_only(nodes, &mut output);
2632                repr.weight.mark_read_only(nodes, &mut output);
2633
2634                if let Some(bias) = &mut repr.bias {
2635                    bias.mark_read_only(nodes, &mut output);
2636                }
2637            }
2638            ModuleOperationIr::ConvTranspose3d(repr) => {
2639                repr.x.mark_read_only(nodes, &mut output);
2640                repr.weight.mark_read_only(nodes, &mut output);
2641
2642                if let Some(bias) = &mut repr.bias {
2643                    bias.mark_read_only(nodes, &mut output);
2644                }
2645            }
2646            ModuleOperationIr::AvgPool1d(repr) => {
2647                repr.x.mark_read_only(nodes, &mut output);
2648            }
2649            ModuleOperationIr::AvgPool2d(repr) => {
2650                repr.x.mark_read_only(nodes, &mut output);
2651            }
2652            ModuleOperationIr::AvgPool1dBackward(repr) => {
2653                repr.x.mark_read_only(nodes, &mut output);
2654                repr.grad.mark_read_only(nodes, &mut output);
2655            }
2656            ModuleOperationIr::AvgPool2dBackward(repr) => {
2657                repr.x.mark_read_only(nodes, &mut output);
2658                repr.grad.mark_read_only(nodes, &mut output);
2659            }
2660            ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
2661                repr.x.mark_read_only(nodes, &mut output);
2662            }
2663            ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
2664                repr.x.mark_read_only(nodes, &mut output);
2665            }
2666            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2667                repr.x.mark_read_only(nodes, &mut output);
2668                repr.grad.mark_read_only(nodes, &mut output);
2669            }
2670            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2671                repr.x.mark_read_only(nodes, &mut output);
2672                repr.grad.mark_read_only(nodes, &mut output);
2673            }
2674            ModuleOperationIr::MaxPool1d(repr) => {
2675                repr.x.mark_read_only(nodes, &mut output);
2676            }
2677            ModuleOperationIr::MaxPool1dWithIndices(repr) => {
2678                repr.x.mark_read_only(nodes, &mut output);
2679            }
2680            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2681                repr.x.mark_read_only(nodes, &mut output);
2682                repr.grad.mark_read_only(nodes, &mut output);
2683            }
2684            ModuleOperationIr::MaxPool2d(repr) => {
2685                repr.x.mark_read_only(nodes, &mut output);
2686            }
2687            ModuleOperationIr::MaxPool2dWithIndices(repr) => {
2688                repr.x.mark_read_only(nodes, &mut output);
2689            }
2690            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2691                repr.x.mark_read_only(nodes, &mut output);
2692                repr.grad.mark_read_only(nodes, &mut output);
2693            }
2694            ModuleOperationIr::Interpolate(repr) => {
2695                repr.x.mark_read_only(nodes, &mut output);
2696            }
2697            ModuleOperationIr::InterpolateBackward(repr) => {
2698                repr.x.mark_read_only(nodes, &mut output);
2699                repr.grad.mark_read_only(nodes, &mut output);
2700            }
2701        };
2702
2703        output
2704    }
2705}
2706
2707impl InitOperationIr {
2708    fn inputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2709        Box::new([].into_iter())
2710    }
2711    fn outputs(&self) -> Box<dyn Iterator<Item = &TensorIr> + '_> {
2712        Box::new([&self.out].into_iter())
2713    }
2714}
2715
2716impl TensorIr {
2717    fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec<TensorIr>) {
2718        if self.status == TensorStatus::ReadWrite && nodes.contains(&self.id) {
2719            output.push(self.clone());
2720            self.status = TensorStatus::ReadOnly;
2721        }
2722    }
2723}
2724
2725impl core::hash::Hash for RandomOpIr {
2726    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2727        self.out.hash(state);
2728
2729        match self.distribution {
2730            Distribution::Default => 1u8.hash(state),
2731            Distribution::Bernoulli(_) => 2u8.hash(state),
2732            Distribution::Uniform(_, _) => 3u8.hash(state),
2733            Distribution::Normal(_, _) => 4u8.hash(state),
2734        }
2735    }
2736}
2737
/// Extension trait to extract outputs when registering an operation.
pub trait OperationOutput<O> {
    /// Extract a single output.
    ///
    /// # Panics
    ///
    /// The `Vec<O>` implementation panics unless the vector holds exactly one
    /// element.
    fn output(self) -> O;

    /// Extract a fixed number of outputs.
    ///
    /// # Panics
    ///
    /// The `Vec<O>` implementation panics when the vector length differs from `N`.
    fn outputs<const N: usize>(self) -> [O; N];
}

impl<O: core::fmt::Debug> OperationOutput<O> for Vec<O> {
    fn output(self) -> O {
        // Destructuring fixes `N = 1`, so the arity check lives in `outputs`.
        let [tensor] = self.outputs();
        tensor
    }

    fn outputs<const N: usize>(self) -> [O; N] {
        // On a length mismatch the conversion hands the vector back; report the
        // expected and actual counts instead of dumping every element.
        self.try_into().unwrap_or_else(|actual: Vec<O>| {
            panic!(
                "expected {} operation output(s), got {}",
                N,
                actual.len()
            )
        })
    }
}