// burn_ir/operation.rs

1use burn_tensor::Shape;
2use core::hash::Hash;
3use serde::{Deserialize, Serialize};
4
5use alloc::borrow::ToOwned;
6use alloc::boxed::Box;
7use alloc::{string::String, vec, vec::Vec};
8
9use burn_tensor::{
10    DType, Distribution, Slice,
11    ops::{
12        ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
13    },
14    quantization::QuantScheme,
15};
16
17use crate::{ScalarIr, TensorId, TensorIr, TensorStatus};
18
/// Custom operation in fusion stream, declaring its inputs and outputs.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct CustomOpIr {
    /// Unique identifier of the operation.
    pub id: String,
    /// Input tensors used in the custom operation.
    pub inputs: Vec<TensorIr>,
    /// Output tensors used in the custom operation.
    pub outputs: Vec<TensorIr>,
}
29
30impl CustomOpIr {
31    /// Create a new custom operation intermediate representation.
32    pub fn new(id: &'static str, inputs: &[TensorIr], outputs: &[TensorIr]) -> Self {
33        Self {
34            id: id.to_owned(),
35            inputs: inputs.to_vec(),
36            outputs: outputs.to_vec(),
37        }
38    }
39
40    /// Cast the intermediate representation, and get the in and output tensors.
41    pub fn as_fixed<const N_IN: usize, const N_OUT: usize>(
42        &self,
43    ) -> (&[TensorIr; N_IN], &[TensorIr; N_OUT]) {
44        (
45            self.inputs.as_slice().try_into().expect(
46                "Wrong number of inputs expected (expected {D}, is {}), check your implementation",
47            ),
48            self.outputs.as_slice().try_into().expect(
49                "Wrong number of outputs expected (expected {D}, is {}), check your implementation",
50            ),
51        )
52    }
53
54    fn nodes(&self) -> Vec<&TensorIr> {
55        self.inputs.iter().chain(self.outputs.iter()).collect()
56    }
57}
58
/// Describe all tensor operations possible.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum OperationIr {
    /// Basic operation on a float tensor.
    BaseFloat(BaseOperationIr),
    /// Basic operation on an int tensor.
    BaseInt(BaseOperationIr),
    /// Basic operation on a bool tensor.
    BaseBool(BaseOperationIr),
    /// Numeric operation on a float tensor.
    NumericFloat(DType, NumericOperationIr),
    /// Numeric operation on an int tensor.
    NumericInt(DType, NumericOperationIr),
    /// Operation specific to a bool tensor.
    Bool(BoolOperationIr),
    /// Operation specific to an int tensor.
    Int(IntOperationIr),
    /// Operation specific to a float tensor.
    Float(DType, FloatOperationIr),
    /// Module operation.
    Module(ModuleOperationIr),
    /// Initialize operation.
    Init(InitOperationIr),
    /// A custom operation.
    Custom(CustomOpIr),
    /// A tensor is dropped.
    Drop(TensorIr),
}
87
/// Operation intermediate representation specific to a float tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationIr {
    /// Operation corresponding to [exp](burn_tensor::ops::FloatTensorOps::float_exp).
    Exp(UnaryOpIr),
    /// Operation corresponding to [log](burn_tensor::ops::FloatTensorOps::float_log).
    Log(UnaryOpIr),
    /// Operation corresponding to [log1p](burn_tensor::ops::FloatTensorOps::float_log1p).
    Log1p(UnaryOpIr),
    /// Operation corresponding to [erf](burn_tensor::ops::FloatTensorOps::float_erf).
    Erf(UnaryOpIr),
    /// Operation corresponding to [powf_scalar](burn_tensor::ops::FloatTensorOps::float_powf_scalar).
    PowfScalar(ScalarOpIr),
    /// Operation corresponding to [sqrt](burn_tensor::ops::FloatTensorOps::float_sqrt).
    Sqrt(UnaryOpIr),
    /// Operation corresponding to [cos](burn_tensor::ops::FloatTensorOps::float_cos).
    Cos(UnaryOpIr),
    /// Operation corresponding to [sin](burn_tensor::ops::FloatTensorOps::float_sin).
    Sin(UnaryOpIr),
    /// Operation corresponding to [tanh](burn_tensor::ops::FloatTensorOps::float_tanh).
    Tanh(UnaryOpIr),
    /// Operation corresponding to [round](burn_tensor::ops::FloatTensorOps::float_round).
    Round(UnaryOpIr),
    /// Operation corresponding to [floor](burn_tensor::ops::FloatTensorOps::float_floor).
    Floor(UnaryOpIr),
    /// Operation corresponding to [ceil](burn_tensor::ops::FloatTensorOps::float_ceil).
    Ceil(UnaryOpIr),
    /// Operation corresponding to [trunc](burn_tensor::ops::FloatTensorOps::float_trunc).
    Trunc(UnaryOpIr),
    /// Operation corresponding to [into_int](burn_tensor::ops::FloatTensorOps::float_into_int).
    IntoInt(UnaryOpIr),
    /// Operation corresponding to [matmul](burn_tensor::ops::FloatTensorOps::float_matmul).
    Matmul(BinaryOpIr),
    /// Operation corresponding to [cross](burn_tensor::ops::FloatTensorOps::float_cross).
    Cross(CrossOpIr),
    /// Operation corresponding to [random](burn_tensor::ops::FloatTensorOps::float_random).
    Random(RandomOpIr),
    /// Operation corresponding to [recip](burn_tensor::ops::FloatTensorOps::float_recip).
    Recip(UnaryOpIr),
    /// Operation corresponding to [is_nan](burn_tensor::ops::FloatTensorOps::float_is_nan).
    IsNan(UnaryOpIr),
    /// Operation corresponding to [is_inf](burn_tensor::ops::FloatTensorOps::float_is_inf).
    IsInf(UnaryOpIr),
    /// Operation corresponding to [quantize](burn_tensor::ops::QTensorOps::quantize).
    Quantize(QuantizeOpIr),
    /// Operation corresponding to [dequantize](burn_tensor::ops::QTensorOps::dequantize).
    Dequantize(DequantizeOpIr),
}
136
/// Operation intermediate representation specific to module.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum ModuleOperationIr {
    /// Operation corresponding to [embedding](burn_tensor::ops::ModuleOps::embedding).
    Embedding(EmbeddingOpIr),
    /// Operation corresponding to [embedding_backward](burn_tensor::ops::ModuleOps::embedding_backward).
    EmbeddingBackward(EmbeddingBackwardOpIr),
    /// Operation corresponding to [conv1d](burn_tensor::ops::ModuleOps::conv1d).
    Conv1d(Conv1dOpIr),
    /// Operation corresponding to [conv2d](burn_tensor::ops::ModuleOps::conv2d).
    Conv2d(Conv2dOpIr),
    /// Operation corresponding to [conv3d](burn_tensor::ops::ModuleOps::conv3d).
    Conv3d(Conv3dOpIr),
    /// Operation corresponding to [deform_conv2d](burn_tensor::ops::ModuleOps::deform_conv2d).
    DeformableConv2d(Box<DeformConv2dOpIr>),
    /// Operation corresponding to [deform_conv2d_backward](burn_tensor::ops::ModuleOps::deform_conv2d_backward).
    DeformableConv2dBackward(Box<DeformConv2dBackwardOpIr>),
    /// Operation corresponding to [conv transpose 1d](burn_tensor::ops::ModuleOps::conv_transpose1d).
    ConvTranspose1d(ConvTranspose1dOpIr),
    /// Operation corresponding to [conv transpose 2d](burn_tensor::ops::ModuleOps::conv_transpose2d).
    ConvTranspose2d(ConvTranspose2dOpIr),
    /// Operation corresponding to [conv transpose 3d](burn_tensor::ops::ModuleOps::conv_transpose3d).
    ConvTranspose3d(ConvTranspose3dOpIr),
    /// Operation corresponding to [avg pool 1d](burn_tensor::ops::ModuleOps::avg_pool1d).
    AvgPool1d(AvgPool1dOpIr),
    /// Operation corresponding to [avg pool 2d](burn_tensor::ops::ModuleOps::avg_pool2d).
    AvgPool2d(AvgPool2dOpIr),
    /// Operation corresponding to
    /// [avg pool 1d backward](burn_tensor::ops::ModuleOps::avg_pool1d_backward).
    AvgPool1dBackward(AvgPool1dBackwardOpIr),
    /// Operation corresponding to
    /// [avg pool 2d backward](burn_tensor::ops::ModuleOps::avg_pool2d_backward).
    AvgPool2dBackward(AvgPool2dBackwardOpIr),
    /// Operation corresponding to
    /// [adaptive avg pool 1d](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d).
    AdaptiveAvgPool1d(AdaptiveAvgPool1dOpIr),
    /// Operation corresponding to
    /// [adaptive avg pool 2d](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d).
    AdaptiveAvgPool2d(AdaptiveAvgPool2dOpIr),
    /// Operation corresponding to
    /// [adaptive avg pool 1d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool1d_backward).
    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardOpIr),
    /// Operation corresponding to
    /// [adaptive avg pool 2d backward](burn_tensor::ops::ModuleOps::adaptive_avg_pool2d_backward).
    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardOpIr),
    /// Operation corresponding to
    /// [max pool 1d](burn_tensor::ops::ModuleOps::max_pool1d).
    MaxPool1d(MaxPool1dOpIr),
    /// Operation corresponding to
    /// [max pool 1d with indices](burn_tensor::ops::ModuleOps::max_pool1d_with_indices).
    MaxPool1dWithIndices(MaxPool1dWithIndicesOpIr),
    /// Operation corresponding to
    /// [max pool 1d with indices backward](burn_tensor::ops::ModuleOps::max_pool1d_with_indices_backward).
    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardOpIr),
    /// Operation corresponding to
    /// [max pool 2d](burn_tensor::ops::ModuleOps::max_pool2d).
    MaxPool2d(MaxPool2dOpIr),
    /// Operation corresponding to
    /// [max pool 2d with indices](burn_tensor::ops::ModuleOps::max_pool2d_with_indices).
    MaxPool2dWithIndices(MaxPool2dWithIndicesOpIr),
    /// Operation corresponding to
    /// [max pool 2d with indices backward](burn_tensor::ops::ModuleOps::max_pool2d_with_indices_backward).
    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardOpIr),
    /// Operation corresponding to [interpolate](burn_tensor::ops::ModuleOps::interpolate).
    Interpolate(InterpolateOpIr),
    /// Operation corresponding to [interpolate backward](burn_tensor::ops::ModuleOps::interpolate_backward).
    InterpolateBackward(InterpolateBackwardOpIr),
}
205
/// Basic operations that can be done on any tensor type.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BaseOperationIr {
    /// Operation corresponding to:
    ///
    /// Float => [to device](burn_tensor::ops::FloatTensorOps::float_to_device).
    /// Int => [to device](burn_tensor::ops::IntTensorOps::int_to_device).
    /// Bool => [to device](burn_tensor::ops::BoolTensorOps::bool_to_device).
    ToDevice(TensorIr),
    /// Operation corresponding to:
    ///
    /// Float => [reshape](burn_tensor::ops::FloatTensorOps::float_reshape).
    /// Int => [reshape](burn_tensor::ops::IntTensorOps::int_reshape).
    /// Bool => [reshape](burn_tensor::ops::BoolTensorOps::bool_reshape).
    Reshape(UnaryOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [swap_dims](burn_tensor::ops::FloatTensorOps::float_swap_dims).
    /// Int => [swap_dims](burn_tensor::ops::IntTensorOps::int_swap_dims).
    /// Bool => [swap_dims](burn_tensor::ops::BoolTensorOps::bool_swap_dims).
    SwapDims(SwapDimsOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [permute](burn_tensor::ops::FloatTensorOps::float_permute).
    /// Int => [permute](burn_tensor::ops::IntTensorOps::int_permute).
    /// Bool => [permute](burn_tensor::ops::BoolTensorOps::bool_permute).
    Permute(PermuteOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [flip](burn_tensor::ops::FloatTensorOps::float_flip).
    /// Int => [flip](burn_tensor::ops::IntTensorOps::int_flip).
    /// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip).
    Flip(FlipOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [expand](burn_tensor::ops::FloatTensorOps::float_expand).
    /// Int => [expand](burn_tensor::ops::IntTensorOps::int_expand).
    /// Bool => [expand](burn_tensor::ops::BoolTensorOps::bool_expand).
    Expand(ExpandOpIr),

    /// Unfold windows along an axis.
    Unfold(UnfoldOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice).
    /// Int => [slice](burn_tensor::ops::IntTensorOps::int_slice).
    /// Bool => [slice](burn_tensor::ops::BoolTensorOps::bool_slice).
    Slice(SliceOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [slice assign](burn_tensor::ops::FloatTensorOps::float_slice_assign).
    /// Int => [slice assign](burn_tensor::ops::IntTensorOps::int_slice_assign).
    /// Bool => [slice assign](burn_tensor::ops::BoolTensorOps::bool_slice_assign).
    SliceAssign(SliceAssignOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [equal](burn_tensor::ops::FloatTensorOps::float_equal).
    /// Int => [equal](burn_tensor::ops::IntTensorOps::int_equal).
    /// Bool => [equal](burn_tensor::ops::BoolTensorOps::bool_equal).
    Equal(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [repeat dim](burn_tensor::ops::FloatTensorOps::float_repeat_dim).
    /// Int => [repeat dim](burn_tensor::ops::IntTensorOps::int_repeat_dim).
    /// Bool => [repeat dim](burn_tensor::ops::BoolTensorOps::bool_repeat_dim).
    RepeatDim(RepeatDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [cat](burn_tensor::ops::FloatTensorOps::float_cat).
    /// Int => [cat](burn_tensor::ops::IntTensorOps::int_cat).
    /// Bool => [cat](burn_tensor::ops::BoolTensorOps::bool_cat).
    Cat(CatOpIr),
    /// Cast operation, no direct operation and should be supported by fusion backend.
    Cast(UnaryOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [cumsum](burn_tensor::ops::FloatTensorOps::float_cumsum).
    /// Int => [cumsum](burn_tensor::ops::IntTensorOps::int_cumsum).
    CumSum(DimOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [cumprod](burn_tensor::ops::FloatTensorOps::float_cumprod).
    /// Int => [cumprod](burn_tensor::ops::IntTensorOps::int_cumprod).
    CumProd(DimOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [cummin](burn_tensor::ops::FloatTensorOps::float_cummin).
    /// Int => [cummin](burn_tensor::ops::IntTensorOps::int_cummin).
    CumMin(DimOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [cummax](burn_tensor::ops::FloatTensorOps::float_cummax).
    /// Int => [cummax](burn_tensor::ops::IntTensorOps::int_cummax).
    CumMax(DimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [empty](burn_tensor::ops::FloatTensorOps::float_empty).
    /// Int => [empty](burn_tensor::ops::IntTensorOps::int_empty).
    /// Bool => [empty](burn_tensor::ops::BoolTensorOps::bool_empty).
    Empty(TensorIr),
}
313
/// Numeric operations on int and float tensors.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum NumericOperationIr {
    /// Operation corresponding to:
    ///
    /// Float => [add](burn_tensor::ops::FloatTensorOps::float_add).
    /// Int => [add](burn_tensor::ops::IntTensorOps::int_add).
    Add(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [add scalar](burn_tensor::ops::FloatTensorOps::float_add_scalar).
    /// Int => [add scalar](burn_tensor::ops::IntTensorOps::int_add_scalar).
    AddScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [sub](burn_tensor::ops::FloatTensorOps::float_sub).
    /// Int => [sub](burn_tensor::ops::IntTensorOps::int_sub).
    Sub(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [sub scalar](burn_tensor::ops::FloatTensorOps::float_sub_scalar).
    /// Int => [sub scalar](burn_tensor::ops::IntTensorOps::int_sub_scalar).
    SubScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [div](burn_tensor::ops::FloatTensorOps::float_div).
    /// Int => [div](burn_tensor::ops::IntTensorOps::int_div).
    Div(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [div scalar](burn_tensor::ops::FloatTensorOps::float_div_scalar).
    /// Int => [div scalar](burn_tensor::ops::IntTensorOps::int_div_scalar).
    DivScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [rem](burn_tensor::ops::FloatTensorOps::float_remainder).
    /// Int => [rem](burn_tensor::ops::IntTensorOps::int_remainder).
    Rem(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [rem scalar](burn_tensor::ops::FloatTensorOps::float_remainder_scalar).
    /// Int => [rem scalar](burn_tensor::ops::IntTensorOps::int_remainder_scalar).
    RemScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mul](burn_tensor::ops::FloatTensorOps::float_mul).
    /// Int => [mul](burn_tensor::ops::IntTensorOps::int_mul).
    Mul(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mul scalar](burn_tensor::ops::FloatTensorOps::float_mul_scalar).
    /// Int => [mul scalar](burn_tensor::ops::IntTensorOps::int_mul_scalar).
    MulScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [abs](burn_tensor::ops::FloatTensorOps::float_abs).
    /// Int => [abs](burn_tensor::ops::IntTensorOps::int_abs).
    Abs(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [ones](burn_tensor::ops::FloatTensorOps::float_ones).
    /// Int => [ones](burn_tensor::ops::IntTensorOps::int_ones).
    Ones(TensorIr),
    /// Operation corresponding to:
    ///
    /// Float => [zeros](burn_tensor::ops::FloatTensorOps::float_zeros).
    /// Int => [zeros](burn_tensor::ops::IntTensorOps::int_zeros).
    Zeros(TensorIr),
    /// Operation corresponding to:
    ///
    /// Float => [full](burn_tensor::ops::FloatTensorOps::float_full).
    /// Int => [full](burn_tensor::ops::IntTensorOps::int_full).
    Full((TensorIr, ScalarIr)),
    /// Operation corresponding to:
    ///
    /// Float => [gather](burn_tensor::ops::FloatTensorOps::float_gather).
    /// Int => [gather](burn_tensor::ops::IntTensorOps::int_gather).
    Gather(GatherOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [scatter](burn_tensor::ops::FloatTensorOps::float_scatter).
    /// Int => [scatter](burn_tensor::ops::IntTensorOps::int_scatter).
    Scatter(ScatterOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [select](burn_tensor::ops::FloatTensorOps::float_select).
    /// Int => [select](burn_tensor::ops::IntTensorOps::int_select).
    Select(SelectOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [select assign](burn_tensor::ops::FloatTensorOps::float_select_assign).
    /// Int => [select assign](burn_tensor::ops::IntTensorOps::int_select_assign).
    SelectAssign(SelectAssignOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mask where](burn_tensor::ops::FloatTensorOps::float_mask_where).
    /// Int => [mask where](burn_tensor::ops::IntTensorOps::int_mask_where).
    MaskWhere(MaskWhereOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mask fill](burn_tensor::ops::FloatTensorOps::float_mask_fill).
    /// Int => [mask fill](burn_tensor::ops::IntTensorOps::int_mask_fill).
    MaskFill(MaskFillOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mean dim](burn_tensor::ops::FloatTensorOps::float_mean_dim).
    /// Int => [mean dim](burn_tensor::ops::IntTensorOps::int_mean_dim).
    MeanDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [mean](burn_tensor::ops::FloatTensorOps::float_mean).
    /// Int => [mean](burn_tensor::ops::IntTensorOps::int_mean).
    Mean(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [sum](burn_tensor::ops::FloatTensorOps::float_sum).
    /// Int => [sum](burn_tensor::ops::IntTensorOps::int_sum).
    Sum(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [sum dim](burn_tensor::ops::FloatTensorOps::float_sum_dim).
    /// Int => [sum dim](burn_tensor::ops::IntTensorOps::int_sum_dim).
    SumDim(ReduceDimOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [prod](burn_tensor::ops::FloatTensorOps::float_prod).
    /// Int => [prod](burn_tensor::ops::IntTensorOps::int_prod).
    Prod(UnaryOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [prod dim](burn_tensor::ops::FloatTensorOps::float_prod_dim).
    /// Int => [prod dim](burn_tensor::ops::IntTensorOps::int_prod_dim).
    ProdDim(ReduceDimOpIr),

    /// Operation corresponding to:
    ///
    /// Float => [equal elem](burn_tensor::ops::FloatTensorOps::float_equal_elem).
    /// Int => [equal elem](burn_tensor::ops::IntTensorOps::int_equal_elem).
    EqualElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [greater](burn_tensor::ops::FloatTensorOps::float_greater).
    /// Int => [greater](burn_tensor::ops::IntTensorOps::int_greater).
    Greater(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [greater elem](burn_tensor::ops::FloatTensorOps::float_greater_elem).
    /// Int => [greater elem](burn_tensor::ops::IntTensorOps::int_greater_elem).
    GreaterElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [greater equal](burn_tensor::ops::FloatTensorOps::float_greater_equal).
    /// Int => [greater equal](burn_tensor::ops::IntTensorOps::int_greater_equal).
    GreaterEqual(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [greater equal elem](burn_tensor::ops::FloatTensorOps::float_greater_equal_elem).
    /// Int => [greater equal elem](burn_tensor::ops::IntTensorOps::int_greater_equal_elem).
    GreaterEqualElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [lower](burn_tensor::ops::FloatTensorOps::float_lower).
    /// Int => [lower](burn_tensor::ops::IntTensorOps::int_lower).
    Lower(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [lower elem](burn_tensor::ops::FloatTensorOps::float_lower_elem).
    /// Int => [lower elem](burn_tensor::ops::IntTensorOps::int_lower_elem).
    LowerElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [lower equal](burn_tensor::ops::FloatTensorOps::float_lower_equal).
    /// Int => [lower equal](burn_tensor::ops::IntTensorOps::int_lower_equal).
    LowerEqual(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [lower equal elem](burn_tensor::ops::FloatTensorOps::float_lower_equal_elem).
    /// Int => [lower equal elem](burn_tensor::ops::IntTensorOps::int_lower_equal_elem).
    LowerEqualElem(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [argmax](burn_tensor::ops::FloatTensorOps::float_argmax).
    /// Int => [argmax](burn_tensor::ops::IntTensorOps::int_argmax).
    ArgMax(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [argmin](burn_tensor::ops::FloatTensorOps::float_argmin).
    /// Int => [argmin](burn_tensor::ops::IntTensorOps::int_argmin).
    ArgMin(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max](burn_tensor::ops::FloatTensorOps::float_max).
    /// Int => [max](burn_tensor::ops::IntTensorOps::int_max).
    Max(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max dim with indices](burn_tensor::ops::FloatTensorOps::float_max_dim_with_indices).
    /// Int => [max dim with indices](burn_tensor::ops::IntTensorOps::int_max_dim_with_indices).
    MaxDimWithIndices(ReduceDimWithIndicesOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [min dim with indices](burn_tensor::ops::FloatTensorOps::float_min_dim_with_indices).
    /// Int => [min dim with indices](burn_tensor::ops::IntTensorOps::int_min_dim_with_indices).
    MinDimWithIndices(ReduceDimWithIndicesOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [min](burn_tensor::ops::FloatTensorOps::float_min).
    /// Int => [min](burn_tensor::ops::IntTensorOps::int_min).
    Min(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max dim](burn_tensor::ops::FloatTensorOps::float_max_dim).
    /// Int => [max dim](burn_tensor::ops::IntTensorOps::int_max_dim).
    MaxDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [min dim](burn_tensor::ops::FloatTensorOps::float_min_dim).
    /// Int => [min dim](burn_tensor::ops::IntTensorOps::int_min_dim).
    MinDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max_abs](burn_tensor::ops::FloatTensorOps::float_max_abs).
    /// Int => [max_abs](burn_tensor::ops::IntTensorOps::int_max_abs).
    MaxAbs(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [max_abs dim](burn_tensor::ops::FloatTensorOps::float_max_abs_dim).
    /// Int => [max_abs dim](burn_tensor::ops::IntTensorOps::int_max_abs_dim).
    MaxAbsDim(ReduceDimOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [clamp](burn_tensor::ops::FloatTensorOps::float_clamp).
    /// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
    Clamp(ClampOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [random](burn_tensor::ops::IntTensorOps::int_random).
    IntRandom(RandomOpIr),
    /// Operation corresponding to:
    ///
    /// Float => [powf](burn_tensor::ops::FloatTensorOps::float_powf).
    /// Int => [powf](burn_tensor::ops::IntTensorOps::int_powf).
    Powf(BinaryOpIr),
}
560
/// Operation intermediate representation specific to an int tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationIr {
    /// Operation corresponding to [into float](burn_tensor::ops::IntTensorOps::int_into_float).
    IntoFloat(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise and](burn_tensor::ops::IntTensorOps::bitwise_and).
    BitwiseAnd(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise and scalar](burn_tensor::ops::IntTensorOps::bitwise_and_scalar).
    BitwiseAndScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise or](burn_tensor::ops::IntTensorOps::bitwise_or).
    BitwiseOr(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise or scalar](burn_tensor::ops::IntTensorOps::bitwise_or_scalar).
    BitwiseOrScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise xor](burn_tensor::ops::IntTensorOps::bitwise_xor).
    BitwiseXor(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise xor scalar](burn_tensor::ops::IntTensorOps::bitwise_xor_scalar).
    BitwiseXorScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise not](burn_tensor::ops::IntTensorOps::bitwise_not).
    BitwiseNot(UnaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise left shift](burn_tensor::ops::IntTensorOps::bitwise_left_shift).
    BitwiseLeftShift(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise left shift scalar](burn_tensor::ops::IntTensorOps::bitwise_left_shift_scalar).
    BitwiseLeftShiftScalar(ScalarOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise right shift](burn_tensor::ops::IntTensorOps::bitwise_right_shift).
    BitwiseRightShift(BinaryOpIr),
    /// Operation corresponding to:
    ///
    /// Int => [bitwise right shift scalar](burn_tensor::ops::IntTensorOps::bitwise_right_shift_scalar).
    BitwiseRightShiftScalar(ScalarOpIr),
    /// Operation corresponding to [matmul](burn_tensor::ops::IntTensorOps::int_matmul).
    Matmul(BinaryOpIr),
}
613
/// Operation intermediate representation specific to a bool tensor.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationIr {
    /// Operation corresponding to:
    /// [zeros](burn_tensor::ops::BoolTensorOps::bool_zeros).
    Zeros(TensorIr),
    /// Operation corresponding to:
    /// [ones](burn_tensor::ops::BoolTensorOps::bool_ones).
    Ones(TensorIr),
    /// Operation corresponding to [into float](burn_tensor::ops::BoolTensorOps::bool_into_float).
    IntoFloat(UnaryOpIr),
    /// Operation corresponding to [into int](burn_tensor::ops::BoolTensorOps::bool_into_int).
    IntoInt(UnaryOpIr),
    /// Operation corresponding to [not](burn_tensor::ops::BoolTensorOps::bool_not).
    Not(UnaryOpIr),
    /// Operation corresponding to [and](burn_tensor::ops::BoolTensorOps::bool_and).
    And(BinaryOpIr),
    /// Operation corresponding to [or](burn_tensor::ops::BoolTensorOps::bool_or).
    Or(BinaryOpIr),
}
634
/// Swap dim operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// The first dim to swap.
    pub dim1: usize,
    /// The second dim to swap.
    pub dim2: usize,
}
647
/// Permute operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// The new order of the dimensions.
    pub axes: Vec<usize>,
}
658
/// Expand operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ExpandOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// The new shape.
    pub shape: Shape,
}
669
/// Unfold operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct UnfoldOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,

    /// The selected dim.
    pub dim: usize,
    /// The window size.
    pub size: usize,
    /// The window step along dim.
    pub step: usize,
}
685
/// Flip operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOpIr {
    /// Input tensor intermediate representation.
    pub input: TensorIr,
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// The dimensions to flip.
    pub axes: Vec<usize>,
}
696
/// Random tensor creation intermediate representation.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOpIr {
    /// Output tensor intermediate representation.
    pub out: TensorIr,
    /// The distribution to sample values from.
    pub distribution: Distribution,
}
703
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
/// Declares a tensor has been initialized.
///
/// It is necessary to register for proper orphan detection and avoid memory leak.
pub struct InitOperationIr {
    /// The initialized tensor.
    pub out: TensorIr,
}
712
/// Generic binary (two tensor operands) operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOpIr {
    /// Left-hand side tensor.
    pub lhs: TensorIr,
    /// Right-hand side tensor.
    pub rhs: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
720
/// Cross product operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CrossOpIr {
    /// Left-hand side tensor.
    pub lhs: TensorIr,
    /// Right-hand side tensor.
    pub rhs: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
    /// Dimension along which the product is computed.
    pub dim: usize,
}
729
/// Generic unary (single tensor operand) operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
736
/// Tensor-scalar operation intermediate representation.
// No `Hash` derive: `ScalarIr` carries the scalar value (see `rhs`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOpIr {
    /// Left-hand side tensor.
    pub lhs: TensorIr,
    // TODO: Make that an enum with `Value` and `Id` variants for relative/global
    // conversion.
    /// Right-hand side scalar value.
    pub rhs: ScalarIr,
    /// Output tensor.
    pub out: TensorIr,
}
746
/// IR for reductions along a single dimension (e.g. sum/mean over `axis`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct ReduceDimOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor (reduced along `axis`).
    pub out: TensorIr,
    /// The dimension being reduced.
    pub axis: usize,
}
754
/// IR for operations that operate along a dimension without reducing it.
/// Unlike `ReduceDimOpIr`, the output shape is the same as the input shape.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Hash)]
#[allow(missing_docs)]
pub struct DimOpIr {
    /// Input tensor.
    pub input: TensorIr,
    /// Output tensor (same shape as the input).
    pub out: TensorIr,
    /// The dimension the operation runs along.
    pub axis: usize,
}
764
/// Gather operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOpIr {
    /// Source tensor.
    pub tensor: TensorIr,
    /// Dimension along which elements are gathered.
    pub dim: usize,
    /// Index tensor selecting the elements.
    pub indices: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
773
/// Scatter operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOpIr {
    /// Destination tensor.
    pub tensor: TensorIr,
    /// Dimension along which values are scattered.
    pub dim: usize,
    /// Index tensor selecting the destinations.
    pub indices: TensorIr,
    /// Values written into the destination.
    pub value: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
783
/// Select (index-select) operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOpIr {
    /// Source tensor.
    pub tensor: TensorIr,
    /// Dimension along which indices are applied.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
792
/// Select-assign operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOpIr {
    /// Destination tensor.
    pub tensor: TensorIr,
    /// Dimension along which indices are applied.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorIr,
    /// Values assigned at the selected indices.
    pub value: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
802
/// Slice operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOpIr {
    /// Source tensor.
    pub tensor: TensorIr,
    /// Per-dimension slice ranges.
    pub ranges: Vec<Slice>,
    /// Output tensor.
    pub out: TensorIr,
}
810
811#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
812#[allow(missing_docs)]
813pub struct SliceAssignOpIr {
814    pub tensor: TensorIr,
815    pub ranges: Vec<burn_tensor::Slice>,
816    pub value: TensorIr,
817    pub out: TensorIr,
818}
819
/// Mask-where operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOpIr {
    /// Source tensor.
    pub tensor: TensorIr,
    /// Boolean mask tensor.
    pub mask: TensorIr,
    /// Values used where the mask applies.
    pub value: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
828
/// Mask-fill operation intermediate representation.
// No `Hash` derive: `value` is a `ScalarIr`, as in `ScalarOpIr`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOpIr {
    /// Source tensor.
    pub tensor: TensorIr,
    /// Boolean mask tensor.
    pub mask: TensorIr,
    /// Scalar fill value used where the mask applies.
    pub value: ScalarIr,
    /// Output tensor.
    pub out: TensorIr,
}
837
/// Clamp operation intermediate representation.
// No `Hash` derive: bounds are `ScalarIr` values.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOpIr {
    /// Input tensor.
    pub tensor: TensorIr,
    /// Lower bound.
    pub min: ScalarIr,
    /// Upper bound.
    pub max: ScalarIr,
    /// Output tensor.
    pub out: TensorIr,
}
846
/// Repeat-dim operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatDimOpIr {
    /// Input tensor.
    pub tensor: TensorIr,
    /// Dimension being repeated.
    pub dim: usize,
    /// Number of repetitions.
    pub times: usize,
    /// Output tensor.
    pub out: TensorIr,
}
855
/// Concatenation operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOpIr {
    /// Tensors concatenated together.
    pub tensors: Vec<TensorIr>,
    /// Dimension along which concatenation happens.
    pub dim: usize,
    /// Output tensor.
    pub out: TensorIr,
}
863
/// Dimension reduction that also produces the winning indices (e.g. max/min with indices).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesOpIr {
    /// Input tensor.
    pub tensor: TensorIr,
    /// Dimension being reduced.
    pub dim: usize,
    /// Output values tensor.
    pub out: TensorIr,
    /// Output indices tensor.
    pub out_indices: TensorIr,
}
872
/// Embedding lookup operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingOpIr {
    /// Embedding weight table.
    pub weights: TensorIr,
    /// Indices into the table.
    pub indices: TensorIr,
    /// Output tensor.
    pub out: TensorIr,
}
880
/// Backward pass of the embedding lookup.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardOpIr {
    /// Embedding weight table from the forward pass.
    pub weights: TensorIr,
    /// Gradient flowing in from the output.
    pub out_grad: TensorIr,
    /// Indices used in the forward pass.
    pub indices: TensorIr,
    /// Output gradient tensor.
    pub out: TensorIr,
}
889
/// 1D convolution operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel weights.
    pub weight: TensorIr,
    /// Optional bias tensor.
    pub bias: Option<TensorIr>,
    /// Stride/padding/dilation/groups settings.
    pub options: Conv1dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
899
/// 2D convolution operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel weights.
    pub weight: TensorIr,
    /// Optional bias tensor.
    pub bias: Option<TensorIr>,
    /// Stride/padding/dilation/groups settings.
    pub options: Conv2dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
909
/// Deformable 2D convolution operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Sampling offset tensor.
    pub offset: TensorIr,
    /// Convolution kernel weights.
    pub weight: TensorIr,
    /// Optional modulation mask tensor.
    pub mask: Option<TensorIr>,
    /// Optional bias tensor.
    pub bias: Option<TensorIr>,
    /// Deformable convolution settings.
    pub options: DeformableConv2dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
921
/// Backward pass of the deformable 2D convolution.
///
/// Produces one gradient per forward input; optional gradients mirror the
/// optional forward inputs (`mask`, `bias`).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Forward sampling offset tensor.
    pub offset: TensorIr,
    /// Forward kernel weights.
    pub weight: TensorIr,
    /// Forward modulation mask, if any.
    pub mask: Option<TensorIr>,
    /// Forward bias, if any.
    pub bias: Option<TensorIr>,
    /// Gradient flowing in from the output.
    pub out_grad: TensorIr,
    /// Deformable convolution settings.
    pub options: DeformableConv2dOptionsIr,
    /// Gradient w.r.t. the input.
    pub input_grad: TensorIr,
    /// Gradient w.r.t. the offsets.
    pub offset_grad: TensorIr,
    /// Gradient w.r.t. the weights.
    pub weight_grad: TensorIr,
    /// Gradient w.r.t. the mask, if a mask was used.
    pub mask_grad: Option<TensorIr>,
    /// Gradient w.r.t. the bias, if a bias was used.
    pub bias_grad: Option<TensorIr>,
}
938
/// 3D convolution operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel weights.
    pub weight: TensorIr,
    /// Optional bias tensor.
    pub bias: Option<TensorIr>,
    /// Stride/padding/dilation/groups settings.
    pub options: Conv3dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
948
/// 1D transposed convolution operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel weights.
    pub weight: TensorIr,
    /// Optional bias tensor.
    pub bias: Option<TensorIr>,
    /// Transposed convolution settings.
    pub options: ConvTranspose1dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
958
/// 2D transposed convolution operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel weights.
    pub weight: TensorIr,
    /// Optional bias tensor.
    pub bias: Option<TensorIr>,
    /// Transposed convolution settings.
    pub options: ConvTranspose2dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
968
/// 3D transposed convolution operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Convolution kernel weights.
    pub weight: TensorIr,
    /// Optional bias tensor.
    pub bias: Option<TensorIr>,
    /// Transposed convolution settings.
    pub options: ConvTranspose3dOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
978
/// Hashable/serializable mirror of `ConvOptions<1>` (see `From` impls below in the file).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 1],
    /// Padding per spatial dimension.
    pub padding: [usize; 1],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 1],
    /// Number of convolution groups.
    pub groups: usize,
}
987
/// Hashable/serializable mirror of `ConvOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Number of convolution groups.
    pub groups: usize,
}
996
/// Hashable/serializable mirror of `DeformConvOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Number of weight groups.
    pub weight_groups: usize,
    /// Number of offset groups.
    pub offset_groups: usize,
}
1006
/// Hashable/serializable mirror of `ConvOptions<3>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 3],
    /// Padding per spatial dimension.
    pub padding: [usize; 3],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 3],
    /// Number of convolution groups.
    pub groups: usize,
}
1015
/// Hashable/serializable mirror of `ConvTransposeOptions<1>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 1],
    /// Padding per spatial dimension.
    pub padding: [usize; 1],
    /// Output padding per spatial dimension.
    pub padding_out: [usize; 1],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 1],
    /// Number of convolution groups.
    pub groups: usize,
}
1025
/// Hashable/serializable mirror of `ConvTransposeOptions<2>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Output padding per spatial dimension.
    pub padding_out: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Number of convolution groups.
    pub groups: usize,
}
1035
/// Hashable/serializable mirror of `ConvTransposeOptions<3>`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOptionsIr {
    /// Stride per spatial dimension.
    pub stride: [usize; 3],
    /// Padding per spatial dimension.
    pub padding: [usize; 3],
    /// Output padding per spatial dimension.
    pub padding_out: [usize; 3],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 3],
    /// Number of convolution groups.
    pub groups: usize,
}
1045
/// Quantization parameters intermediate representation.
// Note: this is one of the few `Eq` types here — it contains no float fields directly.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersIr {
    /// The scaling factor.
    pub scales: TensorIr,
}
1052
/// Quantize operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOpIr {
    /// Input (float) tensor.
    pub tensor: TensorIr,
    /// Quantization parameters (scales).
    pub qparams: QuantizationParametersIr,
    /// Quantization scheme to apply.
    pub scheme: QuantScheme,
    /// Output (quantized) tensor.
    pub out: TensorIr,
}
1061
/// Dequantize operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOpIr {
    /// Input (quantized) tensor.
    pub input: TensorIr,
    /// Output (float) tensor.
    pub out: TensorIr,
}
1068
1069impl From<ConvOptions<1>> for Conv1dOptionsIr {
1070    fn from(value: ConvOptions<1>) -> Self {
1071        Self {
1072            stride: value.stride,
1073            padding: value.padding,
1074            dilation: value.dilation,
1075            groups: value.groups,
1076        }
1077    }
1078}
1079
1080impl From<ConvOptions<2>> for Conv2dOptionsIr {
1081    fn from(value: ConvOptions<2>) -> Self {
1082        Self {
1083            stride: value.stride,
1084            padding: value.padding,
1085            dilation: value.dilation,
1086            groups: value.groups,
1087        }
1088    }
1089}
1090
1091impl From<ConvOptions<3>> for Conv3dOptionsIr {
1092    fn from(value: ConvOptions<3>) -> Self {
1093        Self {
1094            stride: value.stride,
1095            padding: value.padding,
1096            dilation: value.dilation,
1097            groups: value.groups,
1098        }
1099    }
1100}
1101
1102impl From<DeformConvOptions<2>> for DeformableConv2dOptionsIr {
1103    fn from(value: DeformConvOptions<2>) -> Self {
1104        Self {
1105            stride: value.stride,
1106            padding: value.padding,
1107            dilation: value.dilation,
1108            weight_groups: value.weight_groups,
1109            offset_groups: value.offset_groups,
1110        }
1111    }
1112}
1113
1114impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsIr {
1115    fn from(value: ConvTransposeOptions<1>) -> Self {
1116        Self {
1117            stride: value.stride,
1118            padding: value.padding,
1119            padding_out: value.padding_out,
1120            dilation: value.dilation,
1121            groups: value.groups,
1122        }
1123    }
1124}
1125
1126impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsIr {
1127    fn from(value: ConvTransposeOptions<2>) -> Self {
1128        Self {
1129            stride: value.stride,
1130            padding: value.padding,
1131            padding_out: value.padding_out,
1132            dilation: value.dilation,
1133            groups: value.groups,
1134        }
1135    }
1136}
1137
1138impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsIr {
1139    fn from(value: ConvTransposeOptions<3>) -> Self {
1140        Self {
1141            stride: value.stride,
1142            padding: value.padding,
1143            padding_out: value.padding_out,
1144            dilation: value.dilation,
1145            groups: value.groups,
1146        }
1147    }
1148}
1149
1150impl From<Conv1dOptionsIr> for ConvOptions<1> {
1151    fn from(val: Conv1dOptionsIr) -> Self {
1152        ConvOptions {
1153            stride: val.stride,
1154            padding: val.padding,
1155            dilation: val.dilation,
1156            groups: val.groups,
1157        }
1158    }
1159}
1160
1161impl From<Conv2dOptionsIr> for ConvOptions<2> {
1162    fn from(val: Conv2dOptionsIr) -> Self {
1163        ConvOptions {
1164            stride: val.stride,
1165            padding: val.padding,
1166            dilation: val.dilation,
1167            groups: val.groups,
1168        }
1169    }
1170}
1171
1172impl From<Conv3dOptionsIr> for ConvOptions<3> {
1173    fn from(val: Conv3dOptionsIr) -> Self {
1174        ConvOptions {
1175            stride: val.stride,
1176            padding: val.padding,
1177            dilation: val.dilation,
1178            groups: val.groups,
1179        }
1180    }
1181}
1182
1183impl From<DeformableConv2dOptionsIr> for DeformConvOptions<2> {
1184    fn from(value: DeformableConv2dOptionsIr) -> Self {
1185        DeformConvOptions {
1186            stride: value.stride,
1187            padding: value.padding,
1188            dilation: value.dilation,
1189            weight_groups: value.weight_groups,
1190            offset_groups: value.offset_groups,
1191        }
1192    }
1193}
1194
1195impl From<ConvTranspose1dOptionsIr> for ConvTransposeOptions<1> {
1196    fn from(val: ConvTranspose1dOptionsIr) -> Self {
1197        ConvTransposeOptions {
1198            stride: val.stride,
1199            padding: val.padding,
1200            padding_out: val.padding_out,
1201            dilation: val.dilation,
1202            groups: val.groups,
1203        }
1204    }
1205}
1206
1207impl From<ConvTranspose2dOptionsIr> for ConvTransposeOptions<2> {
1208    fn from(val: ConvTranspose2dOptionsIr) -> Self {
1209        ConvTransposeOptions {
1210            stride: val.stride,
1211            padding: val.padding,
1212            padding_out: val.padding_out,
1213            dilation: val.dilation,
1214            groups: val.groups,
1215        }
1216    }
1217}
1218
1219impl From<ConvTranspose3dOptionsIr> for ConvTransposeOptions<3> {
1220    fn from(val: ConvTranspose3dOptionsIr) -> Self {
1221        ConvTransposeOptions {
1222            stride: val.stride,
1223            padding: val.padding,
1224            padding_out: val.padding_out,
1225            dilation: val.dilation,
1226            groups: val.groups,
1227        }
1228    }
1229}
1230
/// 1D average pooling operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Window stride.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Whether padded elements count toward the average.
    pub count_include_pad: bool,
    /// Output tensor.
    pub out: TensorIr,
}
1241
/// 2D average pooling operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Window stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Whether padded elements count toward the average.
    pub count_include_pad: bool,
    /// Output tensor.
    pub out: TensorIr,
}
1252
/// Backward pass of 1D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient flowing in from the output.
    pub grad: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Window stride.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Whether padded elements count toward the average.
    pub count_include_pad: bool,
    /// Output gradient tensor.
    pub out: TensorIr,
}
1264
/// Backward pass of 2D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient flowing in from the output.
    pub grad: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Window stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Whether padded elements count toward the average.
    pub count_include_pad: bool,
    /// Output gradient tensor.
    pub out: TensorIr,
}
1276
/// 1D adaptive average pooling operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Target output length.
    pub output_size: usize,
    /// Output tensor.
    pub out: TensorIr,
}
1284
/// 2D adaptive average pooling operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Target output size per spatial dimension.
    pub output_size: [usize; 2],
    /// Output tensor.
    pub out: TensorIr,
}
1292
/// Backward pass of 1D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient flowing in from the output.
    pub grad: TensorIr,
    /// Output gradient tensor.
    pub out: TensorIr,
}
1300
/// Backward pass of 2D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient flowing in from the output.
    pub grad: TensorIr,
    /// Output gradient tensor.
    pub out: TensorIr,
}
1308
/// 1D max pooling operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Window stride.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Dilation of the pooling window.
    pub dilation: usize,
    /// Output tensor.
    pub out: TensorIr,
}
1319
/// 1D max pooling that also produces the indices of the maxima.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Window stride.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Dilation of the pooling window.
    pub dilation: usize,
    /// Output values tensor.
    pub out: TensorIr,
    /// Output indices tensor.
    pub out_indices: TensorIr,
}
1331
/// Backward pass of 1D max pooling with indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient flowing in from the output.
    pub grad: TensorIr,
    /// Indices produced by the forward pass.
    pub indices: TensorIr,
    /// Pooling window size.
    pub kernel_size: usize,
    /// Window stride.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Dilation of the pooling window.
    pub dilation: usize,
    /// Output gradient tensor.
    pub out: TensorIr,
}
1344
/// 2D max pooling operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Window stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Output tensor.
    pub out: TensorIr,
}
1355
/// 2D max pooling that also produces the indices of the maxima.
// NOTE(review): attribute order differs from siblings (`allow` before `derive`) — harmless.
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MaxPool2dWithIndicesOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Window stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Output values tensor.
    pub out: TensorIr,
    /// Output indices tensor.
    pub out_indices: TensorIr,
}
1367
/// Backward pass of 2D max pooling with indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient flowing in from the output.
    pub grad: TensorIr,
    /// Indices produced by the forward pass.
    pub indices: TensorIr,
    /// Pooling window size per spatial dimension.
    pub kernel_size: [usize; 2],
    /// Window stride per spatial dimension.
    pub stride: [usize; 2],
    /// Padding per spatial dimension.
    pub padding: [usize; 2],
    /// Dilation per spatial dimension.
    pub dilation: [usize; 2],
    /// Output gradient tensor.
    pub out: TensorIr,
}
1380
/// Serializable mirror of `burn_tensor::ops::InterpolateMode` (see the `From` conversions in this file).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeIr {
    /// Nearest-neighbor interpolation.
    Nearest,
    /// Bilinear interpolation.
    Bilinear,
    /// Bicubic interpolation.
    Bicubic,
}
1388
/// Serializable mirror of `burn_tensor::ops::InterpolateOptions`.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsIr {
    /// Interpolation mode.
    pub mode: InterpolateModeIr,
}
1394
/// Interpolate (resize) operation intermediate representation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOpIr {
    /// Input tensor.
    pub x: TensorIr,
    /// Target output size per spatial dimension.
    pub output_size: [usize; 2],
    /// Interpolation options.
    pub options: InterpolateOptionsIr,
    /// Output tensor.
    pub out: TensorIr,
}
1403
1404impl From<InterpolateModeIr> for InterpolateMode {
1405    fn from(val: InterpolateModeIr) -> Self {
1406        match val {
1407            InterpolateModeIr::Nearest => Self::Nearest,
1408            InterpolateModeIr::Bilinear => Self::Bilinear,
1409            InterpolateModeIr::Bicubic => Self::Bicubic,
1410        }
1411    }
1412}
1413
1414impl From<InterpolateOptionsIr> for InterpolateOptions {
1415    fn from(val: InterpolateOptionsIr) -> Self {
1416        Self {
1417            mode: val.mode.into(),
1418        }
1419    }
1420}
1421
1422impl From<InterpolateMode> for InterpolateModeIr {
1423    fn from(val: InterpolateMode) -> Self {
1424        match val {
1425            InterpolateMode::Nearest => Self::Nearest,
1426            InterpolateMode::Bilinear => Self::Bilinear,
1427            InterpolateMode::Bicubic => Self::Bicubic,
1428        }
1429    }
1430}
1431
1432impl From<InterpolateOptions> for InterpolateOptionsIr {
1433    fn from(val: InterpolateOptions) -> Self {
1434        Self {
1435            mode: val.mode.into(),
1436        }
1437    }
1438}
1439
/// Backward pass of the interpolate (resize) operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardOpIr {
    /// Forward input tensor.
    pub x: TensorIr,
    /// Gradient flowing in from the output.
    pub grad: TensorIr,
    /// Target output size per spatial dimension used in the forward pass.
    pub output_size: [usize; 2],
    /// Interpolation options.
    pub options: InterpolateOptionsIr,
    /// Output gradient tensor.
    pub out: TensorIr,
}
1449
impl OperationIr {
    /// Get all [tensor](TensorIr) involved with the current operation.
    pub fn nodes(&self) -> Vec<&TensorIr> {
        // Every variant delegates to its payload's `nodes`, except `Drop`,
        // which holds a bare `TensorIr` rather than a dedicated IR struct.
        match self {
            OperationIr::BaseFloat(repr) => repr.nodes(),
            OperationIr::BaseInt(repr) => repr.nodes(),
            OperationIr::BaseBool(repr) => repr.nodes(),
            OperationIr::NumericFloat(_dtype, repr) => repr.nodes(),
            OperationIr::NumericInt(_dtype, repr) => repr.nodes(),
            OperationIr::Bool(repr) => repr.nodes(),
            OperationIr::Int(repr) => repr.nodes(),
            OperationIr::Float(_dtype, repr) => repr.nodes(),
            OperationIr::Module(repr) => repr.nodes(),
            OperationIr::Init(repr) => repr.nodes(),
            OperationIr::Custom(repr) => repr.nodes(),
            OperationIr::Drop(repr) => vec![repr],
        }
    }

    /// Set the given nodes that are [read write](super::TensorStatus::ReadWrite) to
    /// [read only](super::TensorStatus::ReadOnly) in the current operation.
    ///
    /// Returns the tensor that were updated with their original representation.
    pub fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        match self {
            OperationIr::BaseFloat(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseInt(repr) => repr.mark_read_only(nodes),
            OperationIr::BaseBool(repr) => repr.mark_read_only(nodes),
            OperationIr::NumericFloat(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::NumericInt(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Bool(repr) => repr.mark_read_only(nodes),
            OperationIr::Int(repr) => repr.mark_read_only(nodes),
            OperationIr::Float(_dtype, repr) => repr.mark_read_only(nodes),
            OperationIr::Module(repr) => repr.mark_read_only(nodes),
            // Init declares a new tensor; it has no inputs to downgrade.
            OperationIr::Init(_) => Vec::new(),
            OperationIr::Drop(repr) => {
                let mut output = Vec::new();
                repr.mark_read_only(nodes, &mut output);
                output
            }
            OperationIr::Custom(repr) => {
                let mut output = Vec::new();

                // Only inputs are downgraded; outputs keep their status.
                for input in repr.inputs.iter_mut() {
                    input.mark_read_only(nodes, &mut output);
                }

                output
            }
        }
    }
}
1502
impl BaseOperationIr {
    /// Collect references to every tensor this operation touches (inputs and outputs).
    fn nodes(&self) -> Vec<&TensorIr> {
        match self {
            // ToDevice and Empty hold a bare `TensorIr` payload.
            BaseOperationIr::ToDevice(repr) => vec![repr],
            BaseOperationIr::Reshape(repr) => {
                vec![&repr.input, &repr.out]
            }
            BaseOperationIr::SwapDims(repr) => {
                vec![&repr.input, &repr.out]
            }
            BaseOperationIr::Permute(repr) => {
                vec![&repr.input, &repr.out]
            }

            BaseOperationIr::Expand(repr) => {
                vec![&repr.input, &repr.out]
            }

            BaseOperationIr::Flip(repr) => {
                vec![&repr.input, &repr.out]
            }
            BaseOperationIr::Slice(repr) => {
                vec![&repr.tensor, &repr.out]
            }
            BaseOperationIr::SliceAssign(repr) => {
                vec![&repr.tensor, &repr.value, &repr.out]
            }
            BaseOperationIr::Equal(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            BaseOperationIr::RepeatDim(repr) => {
                vec![&repr.tensor, &repr.out]
            }
            BaseOperationIr::Cat(repr) => {
                // Variable number of inputs, followed by the single output.
                let mut tensors: Vec<_> = repr.tensors.iter().collect();
                tensors.push(&repr.out);
                tensors
            }
            BaseOperationIr::Cast(repr) => vec![&repr.input, &repr.out],
            BaseOperationIr::CumSum(repr) => vec![&repr.input, &repr.out],
            BaseOperationIr::CumProd(repr) => vec![&repr.input, &repr.out],
            BaseOperationIr::CumMin(repr) => vec![&repr.input, &repr.out],
            BaseOperationIr::CumMax(repr) => vec![&repr.input, &repr.out],
            BaseOperationIr::Empty(repr) => vec![repr],
            BaseOperationIr::Unfold(repr) => {
                vec![&repr.input, &repr.out]
            }
        }
    }

    /// Downgrade matching input tensors to read-only, collecting the originals.
    ///
    /// Only input-side tensors are marked; outputs are never touched.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            BaseOperationIr::ToDevice(repr) => {
                repr.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Reshape(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::SwapDims(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Permute(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }

            BaseOperationIr::Expand(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }

            BaseOperationIr::Flip(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Slice(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::SliceAssign(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Equal(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::RepeatDim(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Cat(repr) => {
                for t in repr.tensors.iter_mut() {
                    t.mark_read_only(nodes, &mut output);
                }
            }
            BaseOperationIr::Cast(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::CumSum(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::CumProd(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::CumMin(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::CumMax(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BaseOperationIr::Unfold(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            // Empty has no inputs.
            BaseOperationIr::Empty(_) => {}
        };

        output
    }
}
1620
1621impl NumericOperationIr {
1622    fn nodes(&self) -> Vec<&TensorIr> {
1623        match self {
1624            NumericOperationIr::Add(repr) => {
1625                vec![&repr.lhs, &repr.rhs, &repr.out]
1626            }
1627            NumericOperationIr::AddScalar(repr) => {
1628                vec![&repr.lhs, &repr.out]
1629            }
1630            NumericOperationIr::Sub(repr) => {
1631                vec![&repr.lhs, &repr.rhs, &repr.out]
1632            }
1633            NumericOperationIr::SubScalar(repr) => {
1634                vec![&repr.lhs, &repr.out]
1635            }
1636            NumericOperationIr::Mul(repr) => {
1637                vec![&repr.lhs, &repr.rhs, &repr.out]
1638            }
1639            NumericOperationIr::MulScalar(repr) => {
1640                vec![&repr.lhs, &repr.out]
1641            }
1642            NumericOperationIr::Div(repr) => {
1643                vec![&repr.lhs, &repr.rhs, &repr.out]
1644            }
1645            NumericOperationIr::DivScalar(repr) => {
1646                vec![&repr.lhs, &repr.out]
1647            }
1648            NumericOperationIr::Rem(repr) => {
1649                vec![&repr.lhs, &repr.rhs, &repr.out]
1650            }
1651            NumericOperationIr::RemScalar(repr) => {
1652                vec![&repr.lhs, &repr.out]
1653            }
1654            NumericOperationIr::Ones(repr) => vec![repr],
1655            NumericOperationIr::Gather(repr) => {
1656                vec![&repr.tensor, &repr.indices, &repr.out]
1657            }
1658            NumericOperationIr::Scatter(repr) => {
1659                vec![&repr.tensor, &repr.indices, &repr.value, &repr.out]
1660            }
1661            NumericOperationIr::Select(repr) => {
1662                vec![&repr.tensor, &repr.indices, &repr.out]
1663            }
1664            NumericOperationIr::SelectAssign(repr) => {
1665                vec![&repr.tensor, &repr.indices, &repr.value, &repr.out]
1666            }
1667            NumericOperationIr::MaskWhere(repr) => {
1668                vec![&repr.tensor, &repr.mask, &repr.value, &repr.out]
1669            }
1670            NumericOperationIr::MaskFill(repr) => {
1671                vec![&repr.tensor, &repr.mask, &repr.out]
1672            }
1673            NumericOperationIr::EqualElem(repr) => {
1674                vec![&repr.lhs, &repr.out]
1675            }
1676            NumericOperationIr::GreaterElem(repr) => {
1677                vec![&repr.lhs, &repr.out]
1678            }
1679            NumericOperationIr::GreaterEqualElem(repr) => {
1680                vec![&repr.lhs, &repr.out]
1681            }
1682            NumericOperationIr::LowerElem(repr) => {
1683                vec![&repr.lhs, &repr.out]
1684            }
1685            NumericOperationIr::LowerEqualElem(repr) => {
1686                vec![&repr.lhs, &repr.out]
1687            }
1688            NumericOperationIr::Greater(repr) => {
1689                vec![&repr.lhs, &repr.rhs, &repr.out]
1690            }
1691            NumericOperationIr::GreaterEqual(repr) => {
1692                vec![&repr.lhs, &repr.rhs, &repr.out]
1693            }
1694            NumericOperationIr::Lower(repr) => {
1695                vec![&repr.lhs, &repr.rhs, &repr.out]
1696            }
1697            NumericOperationIr::LowerEqual(repr) => {
1698                vec![&repr.lhs, &repr.rhs, &repr.out]
1699            }
1700            NumericOperationIr::ArgMax(repr) => {
1701                vec![&repr.input, &repr.out]
1702            }
1703            NumericOperationIr::ArgMin(repr) => {
1704                vec![&repr.input, &repr.out]
1705            }
1706            NumericOperationIr::Clamp(repr) => {
1707                vec![&repr.tensor, &repr.out]
1708            }
1709            NumericOperationIr::Abs(repr) => {
1710                vec![&repr.input, &repr.out]
1711            }
1712            NumericOperationIr::Zeros(repr) => vec![repr],
1713            NumericOperationIr::Full(repr) => vec![&repr.0],
1714            NumericOperationIr::MeanDim(repr) => {
1715                vec![&repr.input, &repr.out]
1716            }
1717            NumericOperationIr::Mean(repr) => {
1718                vec![&repr.input, &repr.out]
1719            }
1720            NumericOperationIr::Sum(repr) => {
1721                vec![&repr.input, &repr.out]
1722            }
1723            NumericOperationIr::SumDim(repr) => {
1724                vec![&repr.input, &repr.out]
1725            }
1726            NumericOperationIr::Prod(repr) => {
1727                vec![&repr.input, &repr.out]
1728            }
1729            NumericOperationIr::ProdDim(repr) => {
1730                vec![&repr.input, &repr.out]
1731            }
1732            NumericOperationIr::Max(repr) => {
1733                vec![&repr.input, &repr.out]
1734            }
1735            NumericOperationIr::MaxDimWithIndices(repr) => {
1736                vec![&repr.tensor, &repr.out_indices, &repr.out]
1737            }
1738            NumericOperationIr::MinDimWithIndices(repr) => {
1739                vec![&repr.tensor, &repr.out_indices, &repr.out]
1740            }
1741            NumericOperationIr::Min(repr) => {
1742                vec![&repr.input, &repr.out]
1743            }
1744            NumericOperationIr::MaxDim(repr) => {
1745                vec![&repr.input, &repr.out]
1746            }
1747            NumericOperationIr::MinDim(repr) => {
1748                vec![&repr.input, &repr.out]
1749            }
1750            NumericOperationIr::MaxAbs(repr) => {
1751                vec![&repr.input, &repr.out]
1752            }
1753            NumericOperationIr::MaxAbsDim(repr) => {
1754                vec![&repr.input, &repr.out]
1755            }
1756            NumericOperationIr::IntRandom(repr) => {
1757                vec![&repr.out]
1758            }
1759            NumericOperationIr::Powf(repr) => {
1760                vec![&repr.lhs, &repr.rhs, &repr.out]
1761            }
1762        }
1763    }
    /// Mark every *input* tensor of this operation whose id appears in `nodes`
    /// as read-only, collecting the updated descriptors into the returned vec.
    ///
    /// Only operands that are read by the operation are marked; output tensors
    /// (`out`, `out_indices`, …) are deliberately left untouched. Creation-style
    /// variants (`Ones`, `Zeros`, `Full`, `IntRandom`) have no tensor inputs and
    /// are therefore no-ops.
    ///
    /// NOTE(review): the collection semantics (presumably only tensors whose
    /// status actually changed are pushed) live in `TensorIr::mark_read_only` —
    /// confirm there.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            NumericOperationIr::Add(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::AddScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Sub(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::SubScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Mul(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MulScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Div(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::DivScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Rem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::RemScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            // Tensor-creation op: no input tensors to mark.
            NumericOperationIr::Ones(_) => {}
            NumericOperationIr::Gather(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Scatter(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Select(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::SelectAssign(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.indices.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaskWhere(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.mask.mark_read_only(nodes, &mut output);
                repr.value.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaskFill(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.mask.mark_read_only(nodes, &mut output);
            }
            // The `*Elem` comparisons take a scalar right-hand side, so only
            // the `lhs` tensor is an input.
            NumericOperationIr::EqualElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::GreaterElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::GreaterEqualElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::LowerElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::LowerEqualElem(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Greater(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::GreaterEqual(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Lower(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::LowerEqual(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::ArgMax(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::ArgMin(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Clamp(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Abs(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            // Tensor-creation ops: no input tensors to mark.
            NumericOperationIr::Zeros(_) => {}
            NumericOperationIr::Full(_) => {}
            NumericOperationIr::MeanDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Mean(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Sum(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::SumDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Prod(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::ProdDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Max(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            // `out` and `out_indices` are outputs; only `tensor` is read.
            NumericOperationIr::MaxDimWithIndices(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MinDimWithIndices(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::Min(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaxDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MinDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaxAbs(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            NumericOperationIr::MaxAbsDim(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            // Tensor-creation op: no input tensors to mark.
            NumericOperationIr::IntRandom(_) => {}
            NumericOperationIr::Powf(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
1927}
1928
impl FloatOperationIr {
    /// List references to every tensor involved in this float operation —
    /// inputs and outputs alike (e.g. `lhs`/`rhs`/`input` plus `out`).
    fn nodes(&self) -> Vec<&TensorIr> {
        match self {
            FloatOperationIr::Matmul(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            FloatOperationIr::Cross(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            // Random only creates `out`; it has no tensor input.
            FloatOperationIr::Random(repr) => vec![&repr.out],
            FloatOperationIr::Exp(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Log(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Log1p(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Erf(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Recip(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::PowfScalar(repr) => vec![&repr.lhs, &repr.out],
            FloatOperationIr::Sqrt(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Cos(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Sin(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Tanh(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Round(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Floor(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Ceil(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::Trunc(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::IntoInt(repr) => vec![&repr.input, &repr.out],
            // The quantization scales are a tensor input too.
            FloatOperationIr::Quantize(repr) => vec![&repr.tensor, &repr.qparams.scales, &repr.out],
            FloatOperationIr::Dequantize(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::IsNan(repr) => vec![&repr.input, &repr.out],
            FloatOperationIr::IsInf(repr) => vec![&repr.input, &repr.out],
        }
    }

    /// Mark every *input* tensor of this operation whose id appears in `nodes`
    /// as read-only, collecting the updated descriptors into the returned vec.
    /// Output tensors are never marked; `Random` has no tensor inputs and is a
    /// no-op.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            FloatOperationIr::Matmul(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Cross(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Random(_) => {}
            FloatOperationIr::Exp(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Log(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Log1p(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Erf(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Recip(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::PowfScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Sqrt(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Cos(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Sin(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Tanh(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Round(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Floor(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Ceil(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Trunc(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            // Both the tensor being quantized and the scales are inputs.
            FloatOperationIr::Quantize(repr) => {
                repr.tensor.mark_read_only(nodes, &mut output);
                repr.qparams.scales.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::Dequantize(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::IntoInt(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::IsNan(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            FloatOperationIr::IsInf(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}
2037
impl IntOperationIr {
    /// List references to every tensor involved in this integer operation —
    /// inputs and outputs alike. The `*Scalar` bitwise variants only reference
    /// `lhs` and `out`; their right-hand operand is not a tensor.
    fn nodes(&self) -> Vec<&TensorIr> {
        match self {
            IntOperationIr::Matmul(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            IntOperationIr::IntoFloat(repr) => vec![&repr.input, &repr.out],
            IntOperationIr::BitwiseAnd(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            IntOperationIr::BitwiseAndScalar(repr) => {
                vec![&repr.lhs, &repr.out]
            }
            IntOperationIr::BitwiseOr(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            IntOperationIr::BitwiseOrScalar(repr) => {
                vec![&repr.lhs, &repr.out]
            }
            IntOperationIr::BitwiseXor(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            IntOperationIr::BitwiseXorScalar(repr) => {
                vec![&repr.lhs, &repr.out]
            }
            IntOperationIr::BitwiseNot(repr) => {
                vec![&repr.input, &repr.out]
            }
            IntOperationIr::BitwiseLeftShift(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            IntOperationIr::BitwiseLeftShiftScalar(repr) => {
                vec![&repr.lhs, &repr.out]
            }
            IntOperationIr::BitwiseRightShift(repr) => {
                vec![&repr.lhs, &repr.rhs, &repr.out]
            }
            IntOperationIr::BitwiseRightShiftScalar(repr) => {
                vec![&repr.lhs, &repr.out]
            }
        }
    }

    /// Mark every *input* tensor of this operation whose id appears in `nodes`
    /// as read-only, collecting the updated descriptors into the returned vec.
    /// Output tensors are never marked.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            IntOperationIr::Matmul(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::IntoFloat(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseAnd(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseAndScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseOr(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseOrScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseXor(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseXorScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseNot(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseLeftShift(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseLeftShiftScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseRightShift(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            IntOperationIr::BitwiseRightShiftScalar(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}
2135
impl BoolOperationIr {
    /// List references to every tensor involved in this boolean operation —
    /// inputs and outputs alike. For `Zeros`/`Ones` the variant payload is the
    /// output tensor itself, hence the bare `vec![repr]`.
    fn nodes(&self) -> Vec<&TensorIr> {
        match self {
            BoolOperationIr::Zeros(repr) => vec![repr],
            BoolOperationIr::Ones(repr) => vec![repr],
            BoolOperationIr::IntoFloat(repr) => vec![&repr.input, &repr.out],
            BoolOperationIr::IntoInt(repr) => vec![&repr.input, &repr.out],
            BoolOperationIr::Not(repr) => vec![&repr.input, &repr.out],
            BoolOperationIr::And(repr) => vec![&repr.lhs, &repr.rhs, &repr.out],
            BoolOperationIr::Or(repr) => vec![&repr.lhs, &repr.rhs, &repr.out],
        }
    }
    /// Mark every *input* tensor of this operation whose id appears in `nodes`
    /// as read-only, collecting the updated descriptors into the returned vec.
    /// `Zeros`/`Ones` create their tensor from scratch and so have no inputs.
    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
        let mut output = Vec::new();

        match self {
            BoolOperationIr::Zeros(_) => {}
            BoolOperationIr::Ones(_) => {}
            BoolOperationIr::IntoFloat(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BoolOperationIr::IntoInt(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BoolOperationIr::Not(repr) => {
                repr.input.mark_read_only(nodes, &mut output);
            }
            BoolOperationIr::And(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
            BoolOperationIr::Or(repr) => {
                repr.lhs.mark_read_only(nodes, &mut output);
                repr.rhs.mark_read_only(nodes, &mut output);
            }
        };

        output
    }
}
2176
2177impl ModuleOperationIr {
2178    fn nodes(&self) -> Vec<&TensorIr> {
2179        match self {
2180            ModuleOperationIr::Embedding(repr) => {
2181                vec![&repr.weights, &repr.indices, &repr.out]
2182            }
2183            ModuleOperationIr::EmbeddingBackward(repr) => {
2184                vec![&repr.weights, &repr.out_grad, &repr.indices, &repr.out]
2185            }
2186            ModuleOperationIr::Conv1d(repr) => {
2187                if let Some(bias) = &repr.bias {
2188                    vec![&repr.x, &repr.weight, &bias, &repr.out]
2189                } else {
2190                    vec![&repr.x, &repr.weight, &repr.out]
2191                }
2192            }
2193            ModuleOperationIr::Conv2d(repr) => {
2194                if let Some(bias) = &repr.bias {
2195                    vec![&repr.x, &repr.weight, &bias, &repr.out]
2196                } else {
2197                    vec![&repr.x, &repr.weight, &repr.out]
2198                }
2199            }
2200            ModuleOperationIr::Conv3d(repr) => {
2201                if let Some(bias) = &repr.bias {
2202                    vec![&repr.x, &repr.weight, &bias, &repr.out]
2203                } else {
2204                    vec![&repr.x, &repr.weight, &repr.out]
2205                }
2206            }
2207            ModuleOperationIr::DeformableConv2d(repr) => match (&repr.mask, &repr.bias) {
2208                (Some(mask), Some(bias)) => vec![&repr.x, &repr.offset, &repr.weight, &mask, &bias],
2209                (Some(mask), None) => vec![&repr.x, &repr.offset, &repr.weight, &mask],
2210                (None, Some(bias)) => vec![&repr.x, &repr.offset, &repr.weight, &bias],
2211                (None, None) => vec![&repr.x, &repr.offset, &repr.weight],
2212            },
2213            ModuleOperationIr::DeformableConv2dBackward(repr) => {
2214                let mut nodes = Vec::with_capacity(6);
2215                nodes.push(&repr.x);
2216                nodes.push(&repr.offset);
2217                nodes.push(&repr.weight);
2218                nodes.push(&repr.out_grad);
2219
2220                if let Some(mask) = repr.mask.as_ref() {
2221                    nodes.push(mask);
2222                }
2223                if let Some(bias) = repr.bias.as_ref() {
2224                    nodes.push(bias);
2225                }
2226
2227                nodes
2228            }
2229            ModuleOperationIr::ConvTranspose1d(repr) => {
2230                if let Some(bias) = &repr.bias {
2231                    vec![&repr.x, &repr.weight, &bias, &repr.out]
2232                } else {
2233                    vec![&repr.x, &repr.weight, &repr.out]
2234                }
2235            }
2236            ModuleOperationIr::ConvTranspose2d(repr) => {
2237                if let Some(bias) = &repr.bias {
2238                    vec![&repr.x, &repr.weight, &bias, &repr.out]
2239                } else {
2240                    vec![&repr.x, &repr.weight, &repr.out]
2241                }
2242            }
2243            ModuleOperationIr::ConvTranspose3d(repr) => {
2244                if let Some(bias) = &repr.bias {
2245                    vec![&repr.x, &repr.weight, &bias, &repr.out]
2246                } else {
2247                    vec![&repr.x, &repr.weight, &repr.out]
2248                }
2249            }
2250            ModuleOperationIr::AvgPool1d(repr) => {
2251                vec![&repr.x, &repr.out]
2252            }
2253            ModuleOperationIr::AvgPool2d(repr) => {
2254                vec![&repr.x, &repr.out]
2255            }
2256            ModuleOperationIr::AvgPool1dBackward(repr) => {
2257                vec![&repr.x, &repr.out, &repr.grad]
2258            }
2259            ModuleOperationIr::AvgPool2dBackward(repr) => {
2260                vec![&repr.x, &repr.out, &repr.grad]
2261            }
2262            ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
2263                vec![&repr.x, &repr.out]
2264            }
2265            ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
2266                vec![&repr.x, &repr.out]
2267            }
2268            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2269                vec![&repr.x, &repr.out, &repr.grad]
2270            }
2271            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2272                vec![&repr.x, &repr.out, &repr.grad]
2273            }
2274            ModuleOperationIr::MaxPool1d(repr) => {
2275                vec![&repr.x, &repr.out]
2276            }
2277            ModuleOperationIr::MaxPool1dWithIndices(repr) => {
2278                vec![&repr.x, &repr.out, &repr.out_indices]
2279            }
2280            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2281                vec![&repr.x, &repr.out, &repr.indices, &repr.grad]
2282            }
2283            ModuleOperationIr::MaxPool2d(repr) => {
2284                vec![&repr.x, &repr.out]
2285            }
2286            ModuleOperationIr::MaxPool2dWithIndices(repr) => {
2287                vec![&repr.x, &repr.out, &repr.out_indices]
2288            }
2289            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2290                vec![&repr.x, &repr.out, &repr.indices, &repr.grad]
2291            }
2292            ModuleOperationIr::Interpolate(repr) => {
2293                vec![&repr.x, &repr.out]
2294            }
2295            ModuleOperationIr::InterpolateBackward(repr) => {
2296                vec![&repr.x, &repr.out, &repr.grad]
2297            }
2298        }
2299    }
2300
2301    fn mark_read_only(&mut self, nodes: &[TensorId]) -> Vec<TensorIr> {
2302        let mut output = Vec::new();
2303
2304        match self {
2305            ModuleOperationIr::Embedding(repr) => {
2306                repr.weights.mark_read_only(nodes, &mut output);
2307                repr.indices.mark_read_only(nodes, &mut output);
2308            }
2309            ModuleOperationIr::EmbeddingBackward(repr) => {
2310                repr.weights.mark_read_only(nodes, &mut output);
2311                repr.out_grad.mark_read_only(nodes, &mut output);
2312                repr.indices.mark_read_only(nodes, &mut output);
2313            }
2314            ModuleOperationIr::Conv1d(repr) => {
2315                repr.x.mark_read_only(nodes, &mut output);
2316                repr.weight.mark_read_only(nodes, &mut output);
2317
2318                if let Some(bias) = &mut repr.bias {
2319                    bias.mark_read_only(nodes, &mut output);
2320                }
2321            }
2322            ModuleOperationIr::Conv2d(repr) => {
2323                repr.x.mark_read_only(nodes, &mut output);
2324                repr.weight.mark_read_only(nodes, &mut output);
2325
2326                if let Some(bias) = &mut repr.bias {
2327                    bias.mark_read_only(nodes, &mut output);
2328                }
2329            }
2330            ModuleOperationIr::Conv3d(repr) => {
2331                repr.x.mark_read_only(nodes, &mut output);
2332                repr.weight.mark_read_only(nodes, &mut output);
2333
2334                if let Some(bias) = &mut repr.bias {
2335                    bias.mark_read_only(nodes, &mut output);
2336                }
2337            }
2338            ModuleOperationIr::DeformableConv2d(repr) => {
2339                repr.x.mark_read_only(nodes, &mut output);
2340                repr.weight.mark_read_only(nodes, &mut output);
2341                repr.offset.mark_read_only(nodes, &mut output);
2342
2343                match (&mut repr.mask, &mut repr.bias) {
2344                    (Some(mask), Some(bias)) => {
2345                        mask.mark_read_only(nodes, &mut output);
2346                        bias.mark_read_only(nodes, &mut output);
2347                    }
2348                    (Some(mask), None) => {
2349                        mask.mark_read_only(nodes, &mut output);
2350                    }
2351                    (None, Some(bias)) => {
2352                        bias.mark_read_only(nodes, &mut output);
2353                    }
2354                    (None, None) => {}
2355                };
2356            }
2357            ModuleOperationIr::DeformableConv2dBackward(repr) => {
2358                repr.x.mark_read_only(nodes, &mut output);
2359                repr.weight.mark_read_only(nodes, &mut output);
2360                repr.offset.mark_read_only(nodes, &mut output);
2361                repr.out_grad.mark_read_only(nodes, &mut output);
2362
2363                if let Some(mask) = repr.mask.as_mut() {
2364                    mask.mark_read_only(nodes, &mut output);
2365                }
2366                if let Some(bias) = repr.bias.as_mut() {
2367                    bias.mark_read_only(nodes, &mut output);
2368                }
2369            }
2370            ModuleOperationIr::ConvTranspose1d(repr) => {
2371                repr.x.mark_read_only(nodes, &mut output);
2372                repr.weight.mark_read_only(nodes, &mut output);
2373
2374                if let Some(bias) = &mut repr.bias {
2375                    bias.mark_read_only(nodes, &mut output);
2376                }
2377            }
2378            ModuleOperationIr::ConvTranspose2d(repr) => {
2379                repr.x.mark_read_only(nodes, &mut output);
2380                repr.weight.mark_read_only(nodes, &mut output);
2381
2382                if let Some(bias) = &mut repr.bias {
2383                    bias.mark_read_only(nodes, &mut output);
2384                }
2385            }
2386            ModuleOperationIr::ConvTranspose3d(repr) => {
2387                repr.x.mark_read_only(nodes, &mut output);
2388                repr.weight.mark_read_only(nodes, &mut output);
2389
2390                if let Some(bias) = &mut repr.bias {
2391                    bias.mark_read_only(nodes, &mut output);
2392                }
2393            }
2394            ModuleOperationIr::AvgPool1d(repr) => {
2395                repr.x.mark_read_only(nodes, &mut output);
2396            }
2397            ModuleOperationIr::AvgPool2d(repr) => {
2398                repr.x.mark_read_only(nodes, &mut output);
2399            }
2400            ModuleOperationIr::AvgPool1dBackward(repr) => {
2401                repr.x.mark_read_only(nodes, &mut output);
2402                repr.grad.mark_read_only(nodes, &mut output);
2403            }
2404            ModuleOperationIr::AvgPool2dBackward(repr) => {
2405                repr.x.mark_read_only(nodes, &mut output);
2406                repr.grad.mark_read_only(nodes, &mut output);
2407            }
2408            ModuleOperationIr::AdaptiveAvgPool1d(repr) => {
2409                repr.x.mark_read_only(nodes, &mut output);
2410            }
2411            ModuleOperationIr::AdaptiveAvgPool2d(repr) => {
2412                repr.x.mark_read_only(nodes, &mut output);
2413            }
2414            ModuleOperationIr::AdaptiveAvgPool1dBackward(repr) => {
2415                repr.x.mark_read_only(nodes, &mut output);
2416                repr.grad.mark_read_only(nodes, &mut output);
2417            }
2418            ModuleOperationIr::AdaptiveAvgPool2dBackward(repr) => {
2419                repr.x.mark_read_only(nodes, &mut output);
2420                repr.grad.mark_read_only(nodes, &mut output);
2421            }
2422            ModuleOperationIr::MaxPool1d(repr) => {
2423                repr.x.mark_read_only(nodes, &mut output);
2424            }
2425            ModuleOperationIr::MaxPool1dWithIndices(repr) => {
2426                repr.x.mark_read_only(nodes, &mut output);
2427            }
2428            ModuleOperationIr::MaxPool1dWithIndicesBackward(repr) => {
2429                repr.x.mark_read_only(nodes, &mut output);
2430                repr.grad.mark_read_only(nodes, &mut output);
2431            }
2432            ModuleOperationIr::MaxPool2d(repr) => {
2433                repr.x.mark_read_only(nodes, &mut output);
2434            }
2435            ModuleOperationIr::MaxPool2dWithIndices(repr) => {
2436                repr.x.mark_read_only(nodes, &mut output);
2437            }
2438            ModuleOperationIr::MaxPool2dWithIndicesBackward(repr) => {
2439                repr.x.mark_read_only(nodes, &mut output);
2440                repr.grad.mark_read_only(nodes, &mut output);
2441            }
2442            ModuleOperationIr::Interpolate(repr) => {
2443                repr.x.mark_read_only(nodes, &mut output);
2444            }
2445            ModuleOperationIr::InterpolateBackward(repr) => {
2446                repr.x.mark_read_only(nodes, &mut output);
2447                repr.grad.mark_read_only(nodes, &mut output);
2448            }
2449        };
2450
2451        output
2452    }
2453}
2454
2455impl core::hash::Hash for InitOperationIr {
2456    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2457        self.out.hash(state);
2458    }
2459}
2460
2461impl InitOperationIr {
2462    fn nodes(&self) -> Vec<&TensorIr> {
2463        vec![&self.out]
2464    }
2465}
2466
2467impl TensorIr {
2468    fn mark_read_only(&mut self, nodes: &[TensorId], output: &mut Vec<TensorIr>) {
2469        if self.status == TensorStatus::ReadWrite && nodes.contains(&self.id) {
2470            output.push(self.clone());
2471            self.status = TensorStatus::ReadOnly;
2472        }
2473    }
2474}
2475
2476impl core::hash::Hash for RandomOpIr {
2477    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2478        self.out.hash(state);
2479
2480        match self.distribution {
2481            Distribution::Default => 1u8.hash(state),
2482            Distribution::Bernoulli(_) => 2u8.hash(state),
2483            Distribution::Uniform(_, _) => 3u8.hash(state),
2484            Distribution::Normal(_, _) => 4u8.hash(state),
2485        }
2486    }
2487}
2488
2489impl core::hash::Hash for ScalarOpIr {
2490    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2491        self.lhs.hash(state);
2492        self.out.hash(state);
2493    }
2494}
2495
2496impl core::hash::Hash for MaskFillOpIr {
2497    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2498        self.tensor.hash(state);
2499        self.mask.hash(state);
2500        self.out.hash(state);
2501    }
2502}
2503
2504impl core::hash::Hash for ClampOpIr {
2505    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
2506        self.tensor.hash(state);
2507        self.out.hash(state);
2508    }
2509}
2510
// Manual `Hash`: every variant forwards to its inner repr's hash. The match
// is exhaustive on purpose — adding a variant to the enum must force an
// update here.
impl core::hash::Hash for NumericOperationIr {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        match self {
            NumericOperationIr::Add(repr) => repr.hash(state),
            NumericOperationIr::AddScalar(repr) => repr.hash(state),
            NumericOperationIr::Sub(repr) => repr.hash(state),
            NumericOperationIr::SubScalar(repr) => repr.hash(state),
            NumericOperationIr::Div(repr) => repr.hash(state),
            NumericOperationIr::DivScalar(repr) => repr.hash(state),
            NumericOperationIr::Rem(repr) => repr.hash(state),
            NumericOperationIr::RemScalar(repr) => repr.hash(state),
            NumericOperationIr::Mul(repr) => repr.hash(state),
            NumericOperationIr::MulScalar(repr) => repr.hash(state),
            NumericOperationIr::Abs(repr) => repr.hash(state),
            NumericOperationIr::Ones(repr) => repr.hash(state),
            NumericOperationIr::Zeros(repr) => repr.hash(state),
            // Full carries a tuple; only field .0 is hashed — the fill value
            // is excluded (presumably a non-`Hash` scalar — TODO confirm).
            NumericOperationIr::Full(repr) => repr.0.hash(state),
            NumericOperationIr::Gather(repr) => repr.hash(state),
            NumericOperationIr::Scatter(repr) => repr.hash(state),
            NumericOperationIr::Select(repr) => repr.hash(state),
            NumericOperationIr::SelectAssign(repr) => repr.hash(state),
            NumericOperationIr::MaskWhere(repr) => repr.hash(state),
            NumericOperationIr::MaskFill(repr) => repr.hash(state),
            NumericOperationIr::MeanDim(repr) => repr.hash(state),
            NumericOperationIr::Mean(repr) => repr.hash(state),
            NumericOperationIr::Sum(repr) => repr.hash(state),
            NumericOperationIr::SumDim(repr) => repr.hash(state),
            NumericOperationIr::Prod(repr) => repr.hash(state),
            NumericOperationIr::ProdDim(repr) => repr.hash(state),
            NumericOperationIr::EqualElem(repr) => repr.hash(state),
            NumericOperationIr::Greater(repr) => repr.hash(state),
            NumericOperationIr::GreaterElem(repr) => repr.hash(state),
            NumericOperationIr::GreaterEqual(repr) => repr.hash(state),
            NumericOperationIr::GreaterEqualElem(repr) => repr.hash(state),
            NumericOperationIr::Lower(repr) => repr.hash(state),
            NumericOperationIr::LowerElem(repr) => repr.hash(state),
            NumericOperationIr::LowerEqual(repr) => repr.hash(state),
            NumericOperationIr::LowerEqualElem(repr) => repr.hash(state),
            NumericOperationIr::ArgMax(repr) => repr.hash(state),
            NumericOperationIr::ArgMin(repr) => repr.hash(state),
            NumericOperationIr::Max(repr) => repr.hash(state),
            NumericOperationIr::MaxDimWithIndices(repr) => repr.hash(state),
            NumericOperationIr::MinDimWithIndices(repr) => repr.hash(state),
            NumericOperationIr::Min(repr) => repr.hash(state),
            NumericOperationIr::MaxDim(repr) => repr.hash(state),
            NumericOperationIr::MinDim(repr) => repr.hash(state),
            NumericOperationIr::MaxAbs(repr) => repr.hash(state),
            NumericOperationIr::MaxAbsDim(repr) => repr.hash(state),
            NumericOperationIr::Clamp(repr) => repr.hash(state),
            NumericOperationIr::IntRandom(repr) => repr.hash(state),
            NumericOperationIr::Powf(repr) => repr.hash(state),
        }
    }
}