burn_tensor/repr/
operation.rs

1use core::hash::Hash;
2use core::ops::Range;
3use serde::{Deserialize, Serialize};
4
5use alloc::borrow::ToOwned;
6use alloc::boxed::Box;
7use alloc::{string::String, vec, vec::Vec};
8
9use crate::{
10    ops::{
11        ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
12    },
13    quantization::QuantizationScheme,
14    repr::tensor::TensorDescription,
15    DType, Distribution, Element,
16};
17
18/// Custom operation in fusion stream, declaring it's inputs and outputs.
19#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
20pub struct CustomOpDescription {
21    /// Unique identifier of the operation.
22    pub id: String,
23    /// Input tensors used in this the custom operation.
24    pub inputs: Vec<TensorDescription>,
25    /// Output tensors used in this the custom operation.
26    pub outputs: Vec<TensorDescription>,
27}
28
29impl CustomOpDescription {
30    /// Create a new custom operation description.
31    pub fn new(
32        id: &'static str,
33        inputs: &[TensorDescription],
34        outputs: &[TensorDescription],
35    ) -> Self {
36        Self {
37            id: id.to_owned(),
38            inputs: inputs.to_vec(),
39            outputs: outputs.to_vec(),
40        }
41    }
42
43    /// Consume the description, and get the in and output tensors.
44    pub fn consume<const N_IN: usize, const N_OUT: usize>(
45        self,
46    ) -> ([TensorDescription; N_IN], [TensorDescription; N_OUT]) {
47        (
48            self.inputs.try_into().expect(
49                "Wrong number of inputs expected (expected {D}, is {}), check your implementation",
50            ),
51            self.outputs.try_into().expect(
52                "Wrong number of outputs expected (expected {D}, is {}), check your implementation",
53            ),
54        )
55    }
56
57    fn nodes(&self) -> Vec<&TensorDescription> {
58        self.inputs.iter().chain(self.outputs.iter()).collect()
59    }
60}
61
/// Describe all tensor operations possible.
///
/// Each variant wraps the description of a single operation, grouped by the kind of tensor
/// it applies to (float, int or bool), plus module operations and custom operations.
/// The numeric and float-specific variants additionally carry a [`DType`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum OperationDescription {
    /// Basic operation on a float tensor.
    BaseFloat(BaseOperationDescription),
    /// Basic operation on an int tensor.
    BaseInt(BaseOperationDescription),
    /// Basic operation on a bool tensor.
    BaseBool(BaseOperationDescription),
    /// Numeric operation on a float tensor.
    ///
    /// Scalar operands of the wrapped description are stored as `f32`.
    NumericFloat(DType, NumericOperationDescription<f32>),
    /// Numeric operation on an int tensor.
    ///
    /// Scalar operands of the wrapped description are stored as `i32`.
    NumericInt(DType, NumericOperationDescription<i32>),
    /// Operation specific to a bool tensor.
    Bool(BoolOperationDescription),
    /// Operation specific to an int tensor.
    Int(IntOperationDescription),
    /// Operation specific to a float tensor.
    Float(DType, FloatOperationDescription),
    /// Module operation.
    Module(ModuleOperationDescription),
    /// A custom operation.
    Custom(CustomOpDescription),
}
86
/// Operation description specific to a float tensor.
///
/// Operations shared with int tensors are described by [`NumericOperationDescription`]
/// instead; this enum only holds the float-only ones.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum FloatOperationDescription {
    /// Operation corresponding to [exp](crate::ops::FloatTensorOps::float_exp).
    Exp(UnaryOperationDescription),
    /// Operation corresponding to [log](crate::ops::FloatTensorOps::float_log).
    Log(UnaryOperationDescription),
    /// Operation corresponding to [log1p](crate::ops::FloatTensorOps::float_log1p).
    Log1p(UnaryOperationDescription),
    /// Operation corresponding to [erf](crate::ops::FloatTensorOps::float_erf).
    Erf(UnaryOperationDescription),
    /// Operation corresponding to [powf_scalar](crate::ops::FloatTensorOps::float_powf_scalar).
    PowfScalar(ScalarOperationDescription<f32>),
    /// Operation corresponding to [sqrt](crate::ops::FloatTensorOps::float_sqrt).
    Sqrt(UnaryOperationDescription),
    /// Operation corresponding to [cos](crate::ops::FloatTensorOps::float_cos).
    Cos(UnaryOperationDescription),
    /// Operation corresponding to [sin](crate::ops::FloatTensorOps::float_sin).
    Sin(UnaryOperationDescription),
    /// Operation corresponding to [tanh](crate::ops::FloatTensorOps::float_tanh).
    Tanh(UnaryOperationDescription),
    /// Operation corresponding to [round](crate::ops::FloatTensorOps::float_round).
    Round(UnaryOperationDescription),
    /// Operation corresponding to [floor](crate::ops::FloatTensorOps::float_floor).
    Floor(UnaryOperationDescription),
    /// Operation corresponding to [ceil](crate::ops::FloatTensorOps::float_ceil).
    Ceil(UnaryOperationDescription),
    /// Operation corresponding to [into_int](crate::ops::FloatTensorOps::float_into_int).
    IntoInt(UnaryOperationDescription),
    /// Operation corresponding to [matmul](crate::ops::FloatTensorOps::float_matmul).
    Matmul(BinaryOperationDescription),
    /// Operation corresponding to [random](crate::ops::FloatTensorOps::float_random).
    Random(RandomOperationDescription),
    /// Operation corresponding to [recip](crate::ops::FloatTensorOps::float_recip).
    Recip(UnaryOperationDescription),
    /// Operation corresponding to [quantize](crate::ops::QTensorOps::quantize).
    Quantize(QuantizeOperationDescription),
    /// Operation corresponding to [dequantize](crate::ops::QTensorOps::dequantize).
    Dequantize(DequantizeOperationDescription),
}
127
128/// Operation description specific to module.
129#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
130pub enum ModuleOperationDescription {
131    /// Operation corresponding to [embedding](crate::ops::ModuleOps::embedding).
132    Embedding(EmbeddingDescription),
133    /// Operation corresponding to [embedding_backward](crate::ops::ModuleOps::embedding_backward).
134    EmbeddingBackward(EmbeddingBackwardDescription),
135    /// Operation corresponding to [conv1d](crate::ops::ModuleOps::conv1d).
136    Conv1d(Conv1dDescription),
137    /// Operation corresponding to [conv2d](crate::ops::ModuleOps::conv2d).
138    Conv2d(Conv2dDescription),
139    /// Operation corresponding to [conv3d](crate::ops::ModuleOps::conv3d).
140    Conv3d(Conv3dDescription),
141    /// Operation corresponding to [deform_conv2d](crate::ops::ModuleOps::deform_conv2d)
142    DeformableConv2d(Box<DeformConv2dDescription>),
143    /// Operation corresponding to [deform_conv2d_backward](crate::ops::ModuleOps::deform_conv2d_backward)
144    DeformableConv2dBackward(Box<DeformConv2dBackwardDescription>),
145    /// Operation corresponding to [conv transpose 1d](crate::ops::ModuleOps::conv_transpose1d).
146    ConvTranspose1d(ConvTranspose1dDescription),
147    /// Operation corresponding to [conv transpose 2d](crate::ops::ModuleOps::conv_transpose2d).
148    ConvTranspose2d(ConvTranspose2dDescription),
149    /// Operation corresponding to [conv transpose 3d](crate::ops::ModuleOps::conv_transpose3d).
150    ConvTranspose3d(ConvTranspose3dDescription),
151    /// Operation corresponding to [avg pool 1d](crate::ops::ModuleOps::avg_pool1d).
152    AvgPool1d(AvgPool1dDescription),
153    /// Operation corresponding to [avg pool 2d](crate::ops::ModuleOps::avg_pool2d).
154    AvgPool2d(AvgPool2dDescription),
155    /// Operation corresponding to
156    /// [avg pool 1d backward](crate::ops::ModuleOps::avg_pool1d_backward).
157    AvgPool1dBackward(AvgPool1dBackwardDescription),
158    /// Operation corresponding to
159    /// [avg pool 2d backward](crate::ops::ModuleOps::avg_pool2d_backward).
160    AvgPool2dBackward(AvgPool2dBackwardDescription),
161    /// Operation corresponding to
162    /// [adaptive avg pool 1d](crate::ops::ModuleOps::adaptive_avg_pool1d).
163    AdaptiveAvgPool1d(AdaptiveAvgPool1dDescription),
164    /// Operation corresponding to
165    /// [adaptive avg pool 2d](crate::ops::ModuleOps::adaptive_avg_pool2d).
166    AdaptiveAvgPool2d(AdaptiveAvgPool2dDescription),
167    /// Operation corresponding to
168    /// [adaptive avg pool 1d backward](crate::ops::ModuleOps::adaptive_avg_pool1d_backward).
169    AdaptiveAvgPool1dBackward(AdaptiveAvgPool1dBackwardDescription),
170    /// Operation corresponding to
171    /// [adaptive avg pool 2d backward](crate::ops::ModuleOps::adaptive_avg_pool2d_backward).
172    AdaptiveAvgPool2dBackward(AdaptiveAvgPool2dBackwardDescription),
173    /// Operation corresponding to
174    /// [max pool 1d](crate::ops::ModuleOps::max_pool1d).
175    MaxPool1d(MaxPool1dDescription),
176    /// Operation corresponding to
177    /// [max pool 1d with indices](crate::ops::ModuleOps::max_pool1d_with_indices).
178    MaxPool1dWithIndices(MaxPool1dWithIndicesDescription),
179    /// Operation corresponding to
180    /// [max pool 1d with indices backward](crate::ops::ModuleOps::max_pool1d_with_indices_backward).
181    MaxPool1dWithIndicesBackward(MaxPool1dWithIndicesBackwardDescription),
182    /// Operation corresponding to
183    /// [max pool 2d](crate::ops::ModuleOps::max_pool1d).
184    MaxPool2d(MaxPool2dDescription),
185    /// Operation corresponding to
186    /// [max pool 2d with indices](crate::ops::ModuleOps::max_pool2d_with_indices).
187    MaxPool2dWithIndices(MaxPool2dWithIndicesDescription),
188    /// Operation corresponding to
189    /// [max pool 2d with indices backward](crate::ops::ModuleOps::max_pool2d_with_indices_backward).
190    MaxPool2dWithIndicesBackward(MaxPool2dWithIndicesBackwardDescription),
191    /// Operation corresponding to [interpolate](crate::ops::ModuleOps::interpolate).
192    Interpolate(InterpolateDescription),
193    /// Operation corresponding to [interpolate backward](crate::ops::ModuleOps::interpolate_backward).
194    InterpolateBackward(InterpolateBackwardDescription),
195}
196
197/// Basic operations that can be done on any tensor type.
198#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
199pub enum BaseOperationDescription {
200    /// Operation corresponding to:
201    ///
202    /// Float => [to device](crate::ops::FloatTensorOps::float_to_device).
203    /// Int => [to device](crate::ops::IntTensorOps::int_to_device).
204    /// Bool => [to device](crate::ops::BoolTensorOps::bool_to_device).
205    ToDevice(TensorDescription),
206    /// Operation corresponding to:
207    ///
208    /// Float => [reshape](crate::ops::FloatTensorOps::float_reshape).
209    /// Int => [reshape](crate::ops::IntTensorOps::int_reshape).
210    /// Bool => [reshape](crate::ops::BoolTensorOps::bool_reshape).
211    Reshape(ReshapeDescription),
212
213    /// Operation corresponding to:
214    ///
215    /// Float => [swap_dims](crate::ops::FloatTensorOps::float_swap_dims).
216    /// Int => [swap_dims](crate::ops::IntTensorOps::int_swap_dims).
217    /// Bool => [swap_dims](crate::ops::BoolTensorOps::bool_swap_dims).
218    SwapDims(SwapDimsDescription),
219
220    /// Operation corresponding to:
221    ///
222    /// Float => [permute](crate::ops::FloatTensorOps::float_permute).
223    /// Int => [permute](crate::ops::IntTensorOps::int_permute).
224    /// Bool => [permute](crate::ops::BoolTensorOps::bool_permute).
225    Permute(PermuteOperationDescription),
226
227    /// Operation corresponding to:
228    /// Float => [flip](crate::ops::FloatTensorOps::float_flip).
229    /// Int => [flip](crate::ops::IntTensorOps::int_flip).
230    /// Bool => [flip](crate::ops::BoolTensorOps::bool_flip).
231    Flip(FlipOperationDescription),
232
233    /// Operation corresponding to:
234    ///
235    /// Float => [expand](crate::ops::FloatTensorOps::float_expand).
236    /// Int => [expand](crate::ops::IntTensorOps::int_expand).
237    /// Bool => [expand](crate::ops::BoolTensorOps::bool_expand).
238    Expand(ExpandOperationDescription),
239
240    /// Operation corresponding to:
241    ///
242    /// Float => [slice](crate::ops::FloatTensorOps::float_slice).
243    /// Int => [slice](crate::ops::IntTensorOps::int_slice).
244    /// Bool => [slice](crate::ops::BoolTensorOps::bool_slice).
245    Slice(SliceOperationDescription),
246    /// Operation corresponding to:
247    ///
248    /// Float => [slice assign](crate::ops::FloatTensorOps::float_slice_assign).
249    /// Int => [slice assign](crate::ops::IntTensorOps::int_slice_assign).
250    /// Bool => [slice assign](crate::ops::BoolTensorOps::bool_slice_assign).
251    SliceAssign(SliceAssignOperationDescription),
252    /// Operation corresponding to:
253    ///
254    /// Float => [equal](crate::ops::FloatTensorOps::float_equal).
255    /// Int => [equal](crate::ops::IntTensorOps::int_equal).
256    /// Bool => [equal](crate::ops::BoolTensorOps::bool_equal).
257    Equal(BinaryOperationDescription),
258    /// Operation corresponding to:
259    ///
260    /// Float => [repeat dim](crate::ops::FloatTensorOps::float_repeat_dim).
261    /// Int => [repeat dim](crate::ops::IntTensorOps::int_repeat_dim).
262    /// Bool => [repeat dim](crate::ops::BoolTensorOps::bool_repeat_dim).
263    RepeatDim(RepeatDimOperationDescription),
264    /// Operation corresponding to:
265    ///
266    /// Float => [cat](crate::ops::FloatTensorOps::float_cat).
267    /// Int => [cat](crate::ops::IntTensorOps::int_cat).
268    /// Bool => [cat](crate::ops::BoolTensorOps::bool_cat).
269    Cat(CatOperationDescription),
270    /// Cast operation, no direct operation and should be supported by fusion backend.
271    Cast(UnaryOperationDescription),
272
273    /// Operation corresponding to:
274    ///
275    /// Float => [equal](crate::ops::FloatTensorOps::float_empty).
276    /// Int => [equal](crate::ops::IntTensorOps::int_empty).
277    /// Bool => [equal](crate::ops::BoolTensorOps::bool_empty).
278    Empty(TensorDescription),
279}
280
281/// Numeric operations on int and float tensors.
282#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
283pub enum NumericOperationDescription<E> {
284    /// Operation corresponding to:
285    ///
286    /// Float => [add](crate::ops::FloatTensorOps::float_add).
287    /// Int => [add](crate::ops::IntTensorOps::int_add).
288    Add(BinaryOperationDescription),
289    /// Operation corresponding to:
290    ///
291    /// Float => [add scalar](crate::ops::FloatTensorOps::float_add_scalar).
292    /// Int => [add scalar](crate::ops::IntTensorOps::int_add_scalar).
293    AddScalar(ScalarOperationDescription<E>),
294    /// Operation corresponding to:
295    ///
296    /// Float => [sub](crate::ops::FloatTensorOps::float_sub).
297    /// Int => [sub](crate::ops::IntTensorOps::int_sub).
298    Sub(BinaryOperationDescription),
299    /// Operation corresponding to:
300    ///
301    /// Float => [sub scalar](crate::ops::FloatTensorOps::float_sub_scalar).
302    /// Int => [sub scalar](crate::ops::IntTensorOps::int_sub_scalar).
303    SubScalar(ScalarOperationDescription<E>),
304    /// Operation corresponding to:
305    ///
306    /// Float => [div](crate::ops::FloatTensorOps::float_div).
307    /// Int => [div](crate::ops::IntTensorOps::int_div).
308    Div(BinaryOperationDescription),
309    /// Operation corresponding to:
310    ///
311    /// Float => [div scalar](crate::ops::FloatTensorOps::float_div_scalar).
312    /// Int => [div scalar](crate::ops::IntTensorOps::int_div_scalar).
313    DivScalar(ScalarOperationDescription<E>),
314    /// Operation corresponding to:
315    ///
316    /// Float => [rem](crate::ops::FloatTensorOps::float_remainder).
317    /// Int => [rem](crate::ops::IntTensorOps::int_remainder).
318    Rem(BinaryOperationDescription),
319    /// Operation corresponding to:
320    ///
321    /// Float => [rem scalar](crate::ops::FloatTensorOps::float_remainder_scalar).
322    /// Int => [rem scalar](crate::ops::IntTensorOps::int_remainder_scalar).
323    RemScalar(ScalarOperationDescription<E>),
324    /// Operation corresponding to:
325    ///
326    /// Float => [mul](crate::ops::FloatTensorOps::float_mul).
327    /// Int => [mul](crate::ops::IntTensorOps::int_mul).
328    Mul(BinaryOperationDescription),
329    /// Operation corresponding to:
330    ///
331    /// Float => [mul scalar](crate::ops::FloatTensorOps::float_mul_scalar).
332    /// Int => [mul scalar](crate::ops::IntTensorOps::int_mul_scalar).
333    MulScalar(ScalarOperationDescription<E>),
334    /// Operation corresponding to:
335    ///
336    /// Float => [abs](crate::ops::FloatTensorOps::float_abs).
337    /// Int => [abs](crate::ops::IntTensorOps::int_abs).
338    Abs(UnaryOperationDescription),
339    /// Operation corresponding to:
340    ///
341    /// Float => [ones](crate::ops::FloatTensorOps::float_ones).
342    /// Int => [ones](crate::ops::IntTensorOps::int_ones).
343    Ones(TensorDescription),
344    /// Operation corresponding to:
345    ///
346    /// Float => [zeros](crate::ops::FloatTensorOps::float_zeros).
347    /// Int => [zeros](crate::ops::IntTensorOps::int_zeros).
348    Zeros(TensorDescription),
349    /// Operation corresponding to:
350    ///
351    /// Float => [full](crate::ops::FloatTensorOps::float_full).
352    /// Int => [full](crate::ops::IntTensorOps::int_full).
353    Full((TensorDescription, E)),
354    /// Operation corresponding to:
355    ///
356    /// Float => [gather](crate::ops::FloatTensorOps::float_gather).
357    /// Int => [gather](crate::ops::IntTensorOps::int_gather).
358    Gather(GatherOperationDescription),
359    /// Operation corresponding to:
360    ///
361    /// Float => [scatter](crate::ops::FloatTensorOps::float_scatter).
362    /// Int => [scatter](crate::ops::IntTensorOps::int_scatter).
363    Scatter(ScatterOperationDescription),
364    /// Operation corresponding to:
365    ///
366    /// Float => [select](crate::ops::FloatTensorOps::float_select).
367    /// Int => [select](crate::ops::IntTensorOps::int_select).
368    Select(SelectOperationDescription),
369    /// Operation corresponding to:
370    ///
371    /// Float => [select assign](crate::ops::FloatTensorOps::float_select_assign).
372    /// Int => [select assign](crate::ops::IntTensorOps::int_select_assign).
373    SelectAssign(SelectAssignOperationDescription),
374    /// Operation corresponding to:
375    ///
376    /// Float => [mask where](crate::ops::FloatTensorOps::float_mask_where).
377    /// Int => [mask where](crate::ops::IntTensorOps::int_mask_where).
378    MaskWhere(MaskWhereOperationDescription),
379    /// Operation corresponding to:
380    ///
381    /// Float => [mask fill](crate::ops::FloatTensorOps::float_mask_fill).
382    /// Int => [mask fill](crate::ops::IntTensorOps::int_mask_fill).
383    MaskFill(MaskFillOperationDescription<E>),
384    /// Operation corresponding to:
385    ///
386    /// Float => [mean dim](crate::ops::FloatTensorOps::float_mean_dim).
387    /// Int => [mean dim](crate::ops::IntTensorOps::int_mean_dim).
388    MeanDim(ScalarOperationDescription<usize>),
389    /// Operation corresponding to:
390    ///
391    /// Float => [mean](crate::ops::FloatTensorOps::float_mean).
392    /// Int => [mean](crate::ops::IntTensorOps::int_mean).
393    Mean(UnaryOperationDescription),
394    /// Operation corresponding to:
395    ///
396    /// Float => [sum](crate::ops::FloatTensorOps::float_sum).
397    /// Int => [sum](crate::ops::IntTensorOps::int_sum).
398    Sum(UnaryOperationDescription),
399    /// Operation corresponding to:
400    ///
401    /// Float => [sum dim](crate::ops::FloatTensorOps::float_sum_dim).
402    /// Int => [sum dim](crate::ops::IntTensorOps::int_sum_dim).
403    SumDim(ScalarOperationDescription<usize>),
404
405    /// Operation corresponding to:
406    ///
407    /// Float => [prod](crate::ops::FloatTensorOps::float_prod).
408    /// Int => [prod](crate::ops::IntTensorOps::int_prod).
409    Prod(UnaryOperationDescription),
410
411    /// Operation corresponding to:
412    ///
413    /// Float => [prod dim](crate::ops::FloatTensorOps::float_prod_dim).
414    /// Int => [prod dim](crate::ops::IntTensorOps::int_prod_dim).
415    ProdDim(ScalarOperationDescription<usize>),
416
417    /// Operation corresponding to:
418    ///
419    /// Float => [equal elem](crate::ops::FloatTensorOps::float_equal_elem).
420    /// Int => [equal elem](crate::ops::IntTensorOps::int_equal_elem).
421    EqualElem(ScalarOperationDescription<E>),
422    /// Operation corresponding to:
423    ///
424    /// Float => [greater](crate::ops::FloatTensorOps::float_greater).
425    /// Int => [greater](crate::ops::IntTensorOps::int_greater).
426    Greater(BinaryOperationDescription),
427    /// Operation corresponding to:
428    ///
429    /// Float => [greater elem](crate::ops::FloatTensorOps::float_greater_elem).
430    /// Int => [greater elem](crate::ops::IntTensorOps::int_greater_elem).
431    GreaterElem(ScalarOperationDescription<E>),
432    /// Operation corresponding to:
433    ///
434    /// Float => [greater equal](crate::ops::FloatTensorOps::float_greater_elem).
435    /// Int => [greater elem](crate::ops::IntTensorOps::int_greater_elem).
436    GreaterEqual(BinaryOperationDescription),
437    /// Operation corresponding to:
438    ///
439    /// Float => [greater equal elem](crate::ops::FloatTensorOps::float_greater_equal_elem).
440    /// Int => [greater equal elem](crate::ops::IntTensorOps::int_greater_equal_elem).
441    GreaterEqualElem(ScalarOperationDescription<E>),
442    /// Operation corresponding to:
443    ///
444    /// Float => [lower](crate::ops::FloatTensorOps::float_lower).
445    /// Int => [lower](crate::ops::IntTensorOps::int_lower).
446    Lower(BinaryOperationDescription),
447    /// Operation corresponding to:
448    ///
449    /// Float => [lower elem](crate::ops::FloatTensorOps::float_lower_elem).
450    /// Int => [lower elem](crate::ops::IntTensorOps::int_lower_elem).
451    LowerElem(ScalarOperationDescription<E>),
452    /// Operation corresponding to:
453    ///
454    /// Float => [lower equal](crate::ops::FloatTensorOps::float_lower_equal).
455    /// Int => [lower equal](crate::ops::IntTensorOps::int_lower_equal).
456    LowerEqual(BinaryOperationDescription),
457    /// Operation corresponding to:
458    ///
459    /// Float => [lower equal elem](crate::ops::FloatTensorOps::float_lower_equal_elem).
460    /// Int => [lower equal elem](crate::ops::IntTensorOps::int_lower_equal_elem).
461    LowerEqualElem(ScalarOperationDescription<E>),
462    /// Operation corresponding to:
463    ///
464    /// Float => [argmax](crate::ops::FloatTensorOps::float_argmax).
465    /// Int => [argmax](crate::ops::IntTensorOps::int_argmax).
466    ArgMax(ScalarOperationDescription<usize>),
467    /// Operation corresponding to:
468    ///
469    /// Float => [argmin](crate::ops::FloatTensorOps::float_argmin).
470    /// Int => [argmin](crate::ops::IntTensorOps::int_argmin).
471    ArgMin(ScalarOperationDescription<usize>),
472    /// Operation corresponding to:
473    ///
474    /// Float => [max](crate::ops::FloatTensorOps::float_max).
475    /// Int => [max](crate::ops::IntTensorOps::int_max).
476    Max(UnaryOperationDescription),
477    /// Operation corresponding to:
478    ///
479    /// Float => [max dim with indices](crate::ops::FloatTensorOps::float_max_dim_with_indices).
480    /// Int => [max dim with indices](crate::ops::IntTensorOps::int_max_dim_with_indices).
481    MaxDimWithIndices(ReduceDimWithIndicesDescription),
482    /// Operation corresponding to:
483    ///
484    /// Float => [min dim with indices](crate::ops::FloatTensorOps::float_min_dim_with_indices).
485    /// Int => [min dim with indices](crate::ops::IntTensorOps::int_min_dim_with_indices).
486    MinDimWithIndices(ReduceDimWithIndicesDescription),
487    /// Operation corresponding to:
488    ///
489    /// Float => [min](crate::ops::FloatTensorOps::float_min).
490    /// Int => [min](crate::ops::IntTensorOps::int_min).
491    Min(UnaryOperationDescription),
492    /// Operation corresponding to:
493    ///
494    /// Float => [max dim](crate::ops::FloatTensorOps::float_max_dim).
495    /// Int => [max dim](crate::ops::IntTensorOps::int_max_dim).
496    MaxDim(ScalarOperationDescription<usize>),
497    /// Operation corresponding to:
498    ///
499    /// Float => [min dim](crate::ops::FloatTensorOps::float_min_dim).
500    /// Int => [min dim](crate::ops::IntTensorOps::int_min_dim).
501    MinDim(ScalarOperationDescription<usize>),
502    /// Operation corresponding to:
503    ///
504    /// Float => [clamp](crate::ops::FloatTensorOps::float_clamp).
505    /// Int => [clamp](crate::ops::IntTensorOps::int_clamp).
506    Clamp(ClampOperationDescription<E>),
507    /// Operation corresponding to:
508    ///
509    /// Int => [random](crate::ops::IntTensorOps::int_random).
510    IntRandom(RandomOperationDescription),
511    /// Operation corresponding to:
512    ///
513    /// Float => [powf](crate::ops::FloatTensorOps::float_powf).
514    /// Int => [powf](crate::ops::IntTensorOps::int_powf).
515    Powf(BinaryOperationDescription),
516}
517
/// Operation description specific to an int tensor.
///
/// Operations shared with other tensor kinds are described by
/// [`BaseOperationDescription`] and [`NumericOperationDescription`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum IntOperationDescription {
    /// Operation corresponding to [into float](crate::ops::IntTensorOps::int_into_float).
    IntoFloat(UnaryOperationDescription),
}
524
/// Operation description specific to a bool tensor.
///
/// Operations shared with other tensor kinds are described by
/// [`BaseOperationDescription`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub enum BoolOperationDescription {
    /// Operation corresponding to [into float](crate::ops::BoolTensorOps::bool_into_float).
    IntoFloat(UnaryOperationDescription),
    /// Operation corresponding to [into int](crate::ops::BoolTensorOps::bool_into_int).
    IntoInt(UnaryOperationDescription),
    /// Operation corresponding to [not](crate::ops::BoolTensorOps::bool_not).
    Not(UnaryOperationDescription),
}
535
/// Swap dim operation description.
///
/// Used by [`BaseOperationDescription::SwapDims`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct SwapDimsDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// The first dim to swap.
    pub dim1: usize,
    /// The second dim to swap.
    pub dim2: usize,
}
548
/// Permute operation description.
///
/// Used by [`BaseOperationDescription::Permute`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct PermuteOperationDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// The new order of the dimensions.
    // NOTE(review): presumably entry `i` names the input dimension placed at output
    // position `i`; confirm against `float_permute`.
    pub axes: Vec<usize>,
}
559
/// Expand operation description.
///
/// Used by [`BaseOperationDescription::Expand`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct ExpandOperationDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// The new shape.
    pub shape: Vec<usize>,
}
570
/// Flip operation description.
///
/// Used by [`BaseOperationDescription::Flip`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOperationDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
    /// The dimensions to flip.
    pub axes: Vec<usize>,
}
581
/// Description of a random tensor creation.
///
/// Used by [`FloatOperationDescription::Random`] and
/// [`NumericOperationDescription::IntRandom`].
// NOTE(review): `Hash` is not derived, presumably because `Distribution` does not
// implement it; confirm where hashing of this description is handled.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOperationDescription {
    /// Output tensor description.
    pub out: TensorDescription,
    /// Distribution the random values are drawn from.
    pub distribution: Distribution,
}
588
/// Reshape operation description.
///
/// Used by [`BaseOperationDescription::Reshape`]. There is no separate shape field;
/// presumably the target shape is read from `out` (confirm).
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReshapeDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
595
/// Expand description holding only the input and output tensors.
// NOTE(review): not referenced by any enum in this part of the file;
// `BaseOperationDescription::Expand` uses `ExpandOperationDescription` instead.
// Confirm whether this type is still used before relying on it.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ExpandDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
602
/// Description of an operation with two tensor operands.
///
/// Used, for example, by [`NumericOperationDescription::Add`] and
/// [`FloatOperationDescription::Matmul`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct BinaryOperationDescription {
    /// Left-hand side tensor.
    pub lhs: TensorDescription,
    /// Right-hand side tensor.
    pub rhs: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
610
/// Description of an operation with a single tensor operand.
///
/// Used, for example, by [`FloatOperationDescription::Exp`] and
/// [`BaseOperationDescription::Cast`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct UnaryOperationDescription {
    /// Input tensor description.
    pub input: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
617
/// Description of an operation between a tensor and a scalar of type `E`.
///
/// Used, for example, by [`NumericOperationDescription::AddScalar`]. The dimension-wise
/// reductions such as [`NumericOperationDescription::SumDim`] instantiate it with
/// `E = usize`, presumably storing the dimension in `rhs` (confirm).
// NOTE(review): `Hash` is not derived, presumably because `E` can be `f32`.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarOperationDescription<E> {
    /// Tensor operand.
    pub lhs: TensorDescription,
    /// Scalar operand.
    pub rhs: E,
    /// Output tensor description.
    pub out: TensorDescription,
}
625
/// Gather operation description.
///
/// Used by [`NumericOperationDescription::Gather`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct GatherOperationDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Dimension the gather is performed along.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
634
/// Scatter operation description.
///
/// Used by [`NumericOperationDescription::Scatter`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScatterOperationDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Dimension the scatter is performed along.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorDescription,
    /// Tensor holding the values to scatter.
    pub value: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
644
/// Select operation description.
///
/// Used by [`NumericOperationDescription::Select`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectOperationDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Dimension the selection is performed along.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
653
/// Select-assign operation description.
///
/// Used by [`NumericOperationDescription::SelectAssign`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SelectAssignOperationDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Dimension the selection is performed along.
    pub dim: usize,
    /// Index tensor.
    pub indices: TensorDescription,
    /// Tensor holding the values to assign.
    pub value: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
663
/// Slice operation description.
///
/// Used by [`BaseOperationDescription::Slice`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceOperationDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Index ranges to keep; presumably one per dimension, starting at dim 0 (confirm).
    pub ranges: Vec<Range<usize>>,
    /// Output tensor description.
    pub out: TensorDescription,
}
671
/// Slice-assign operation description.
///
/// Used by [`BaseOperationDescription::SliceAssign`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct SliceAssignOperationDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Index ranges of the region to assign; presumably one per dimension (confirm).
    pub ranges: Vec<Range<usize>>,
    /// Tensor holding the values written into the region.
    pub value: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
680
/// Mask-where operation description.
///
/// Used by [`NumericOperationDescription::MaskWhere`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskWhereOperationDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Mask tensor.
    pub mask: TensorDescription,
    /// Tensor providing the replacement values (presumably where the mask is set; confirm).
    pub value: TensorDescription,
    /// Output tensor description.
    pub out: TensorDescription,
}
689
/// Mask-fill operation description, with a scalar fill value of type `E`.
///
/// Used by [`NumericOperationDescription::MaskFill`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaskFillOperationDescription<E> {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Mask tensor.
    pub mask: TensorDescription,
    /// Scalar fill value.
    pub value: E,
    /// Output tensor description.
    pub out: TensorDescription,
}
698
/// Clamp operation description, with scalar bounds of type `E`.
///
/// Used by [`NumericOperationDescription::Clamp`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ClampOperationDescription<E> {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Lower bound.
    pub min: E,
    /// Upper bound.
    pub max: E,
    /// Output tensor description.
    pub out: TensorDescription,
}
707
/// Repeat-dim operation description.
///
/// Used by [`BaseOperationDescription::RepeatDim`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatDimOperationDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Dimension to repeat along.
    pub dim: usize,
    /// Number of repetitions.
    pub times: usize,
    /// Output tensor description.
    pub out: TensorDescription,
}
716
/// Concatenation operation description.
///
/// Used by [`BaseOperationDescription::Cat`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct CatOperationDescription {
    /// Tensors to concatenate.
    pub tensors: Vec<TensorDescription>,
    /// Dimension to concatenate along.
    pub dim: usize,
    /// Output tensor description.
    pub out: TensorDescription,
}
724
/// Description of a dimension reduction that also produces indices.
///
/// Used by [`NumericOperationDescription::MaxDimWithIndices`] and
/// [`NumericOperationDescription::MinDimWithIndices`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ReduceDimWithIndicesDescription {
    /// Input tensor.
    pub tensor: TensorDescription,
    /// Dimension being reduced.
    pub dim: usize,
    /// Output tensor holding the reduced values.
    pub out: TensorDescription,
    /// Output tensor holding the indices of the reduced values.
    pub out_indices: TensorDescription,
}
733
/// Description of an embedding lookup.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingDescription {
    /// Embedding weight table.
    pub weights: TensorDescription,
    /// Indices looked up in `weights`.
    pub indices: TensorDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
741
/// Description of the backward pass of an embedding lookup.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct EmbeddingBackwardDescription {
    /// Embedding weight table used in the forward pass.
    pub weights: TensorDescription,
    /// Gradient flowing back from the embedding output.
    pub out_grad: TensorDescription,
    /// Indices used in the forward pass.
    pub indices: TensorDescription,
    /// Output tensor; presumably the gradient with respect to `weights` (TODO confirm).
    pub out: TensorDescription,
}
750
/// Description of a 1D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Convolution weights.
    pub weight: TensorDescription,
    /// Optional bias tensor; `None` when no bias is applied.
    pub bias: Option<TensorDescription>,
    /// Stride, padding, dilation and groups of the convolution.
    pub options: Conv1dOptionsDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
760
/// Description of a 2D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Convolution weights.
    pub weight: TensorDescription,
    /// Optional bias tensor; `None` when no bias is applied.
    pub bias: Option<TensorDescription>,
    /// Stride, padding, dilation and groups of the convolution.
    pub options: Conv2dOptionsDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
770
/// Description of a 2D deformable convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Offset tensor of the deformable convolution.
    pub offset: TensorDescription,
    /// Convolution weights.
    pub weight: TensorDescription,
    /// Optional mask tensor.
    pub mask: Option<TensorDescription>,
    /// Optional bias tensor.
    pub bias: Option<TensorDescription>,
    /// Stride, padding, dilation, weight groups and offset groups.
    pub options: DeformableConv2dOptionsDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
782
/// Description of the backward pass of a 2D deformable convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardDescription {
    /// Forward input tensor.
    pub x: TensorDescription,
    /// Forward offset tensor.
    pub offset: TensorDescription,
    /// Forward convolution weights.
    pub weight: TensorDescription,
    /// Forward mask tensor, if any.
    pub mask: Option<TensorDescription>,
    /// Forward bias tensor, if any.
    pub bias: Option<TensorDescription>,
    /// Gradient flowing back from the forward output.
    pub out_grad: TensorDescription,
    /// Options used by the forward convolution.
    pub options: DeformableConv2dOptionsDescription,
    /// Gradient with respect to `x`.
    pub input_grad: TensorDescription,
    /// Gradient with respect to `offset`.
    pub offset_grad: TensorDescription,
    /// Gradient with respect to `weight`.
    pub weight_grad: TensorDescription,
    /// Gradient with respect to `mask`; presumably `Some` exactly when `mask` is (TODO confirm).
    pub mask_grad: Option<TensorDescription>,
    /// Gradient with respect to `bias`; presumably `Some` exactly when `bias` is (TODO confirm).
    pub bias_grad: Option<TensorDescription>,
}
799
/// Description of a 3D convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Convolution weights.
    pub weight: TensorDescription,
    /// Optional bias tensor; `None` when no bias is applied.
    pub bias: Option<TensorDescription>,
    /// Stride, padding, dilation and groups of the convolution.
    pub options: Conv3dOptionsDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
809
/// Description of a 1D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Convolution weights.
    pub weight: TensorDescription,
    /// Optional bias tensor; `None` when no bias is applied.
    pub bias: Option<TensorDescription>,
    /// Stride, padding, output padding, dilation and groups.
    pub options: ConvTranspose1dOptionsDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
819
/// Description of a 2D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Convolution weights.
    pub weight: TensorDescription,
    /// Optional bias tensor; `None` when no bias is applied.
    pub bias: Option<TensorDescription>,
    /// Stride, padding, output padding, dilation and groups.
    pub options: ConvTranspose2dOptionsDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
829
/// Description of a 3D transposed convolution.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Convolution weights.
    pub weight: TensorDescription,
    /// Optional bias tensor; `None` when no bias is applied.
    pub bias: Option<TensorDescription>,
    /// Stride, padding, output padding, dilation and groups.
    pub options: ConvTranspose3dOptionsDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
839
/// Serializable form of [`ConvOptions`] for 1D convolutions.
///
/// Converts to and from `ConvOptions<1>` by copying every field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv1dOptionsDescription {
    /// Stride per dimension.
    pub stride: [usize; 1],
    /// Padding per dimension.
    pub padding: [usize; 1],
    /// Dilation per dimension.
    pub dilation: [usize; 1],
    /// Number of groups.
    pub groups: usize,
}
848
/// Serializable form of [`ConvOptions`] for 2D convolutions.
///
/// Converts to and from `ConvOptions<2>` by copying every field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv2dOptionsDescription {
    /// Stride per dimension.
    pub stride: [usize; 2],
    /// Padding per dimension.
    pub padding: [usize; 2],
    /// Dilation per dimension.
    pub dilation: [usize; 2],
    /// Number of groups.
    pub groups: usize,
}
857
/// Serializable form of [`DeformConvOptions`] for 2D deformable convolutions.
///
/// Converts to and from `DeformConvOptions<2>` by copying every field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsDescription {
    /// Stride per dimension.
    pub stride: [usize; 2],
    /// Padding per dimension.
    pub padding: [usize; 2],
    /// Dilation per dimension.
    pub dilation: [usize; 2],
    /// Number of weight groups.
    pub weight_groups: usize,
    /// Number of offset groups.
    pub offset_groups: usize,
}
867
/// Serializable form of [`ConvOptions`] for 3D convolutions.
///
/// Converts to and from `ConvOptions<3>` by copying every field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsDescription {
    /// Stride per dimension.
    pub stride: [usize; 3],
    /// Padding per dimension.
    pub padding: [usize; 3],
    /// Dilation per dimension.
    pub dilation: [usize; 3],
    /// Number of groups.
    pub groups: usize,
}
876
/// Serializable form of [`ConvTransposeOptions`] for 1D transposed convolutions.
///
/// Converts to and from `ConvTransposeOptions<1>` by copying every field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose1dOptionsDescription {
    /// Stride per dimension.
    pub stride: [usize; 1],
    /// Padding per dimension.
    pub padding: [usize; 1],
    /// Additional output padding per dimension.
    pub padding_out: [usize; 1],
    /// Dilation per dimension.
    pub dilation: [usize; 1],
    /// Number of groups.
    pub groups: usize,
}
886
/// Serializable form of [`ConvTransposeOptions`] for 2D transposed convolutions.
///
/// Converts to and from `ConvTransposeOptions<2>` by copying every field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose2dOptionsDescription {
    /// Stride per dimension.
    pub stride: [usize; 2],
    /// Padding per dimension.
    pub padding: [usize; 2],
    /// Additional output padding per dimension.
    pub padding_out: [usize; 2],
    /// Dilation per dimension.
    pub dilation: [usize; 2],
    /// Number of groups.
    pub groups: usize,
}
896
/// Serializable form of [`ConvTransposeOptions`] for 3D transposed convolutions.
///
/// Converts to and from `ConvTransposeOptions<3>` by copying every field.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ConvTranspose3dOptionsDescription {
    /// Stride per dimension.
    pub stride: [usize; 3],
    /// Padding per dimension.
    pub padding: [usize; 3],
    /// Additional output padding per dimension.
    pub padding_out: [usize; 3],
    /// Dilation per dimension.
    pub dilation: [usize; 3],
    /// Number of groups.
    pub groups: usize,
}
906
/// Quantization parameters description.
///
/// Both parameters are stored as tensors rather than scalars.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct QuantizationParametersDescription {
    /// The scaling factor.
    pub scale: TensorDescription,
    /// The zero-point offset; optional, and only listed among a quantize operation's
    /// tensors when present.
    pub offset: Option<TensorDescription>,
}
915
/// Description of a quantization operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct QuantizeOperationDescription {
    /// Tensor to quantize.
    pub tensor: TensorDescription,
    /// Scale and optional zero-point tensors.
    pub qparams: QuantizationParametersDescription,
    /// Quantization scheme to apply.
    pub scheme: QuantizationScheme,
    /// Output (quantized) tensor.
    pub out: TensorDescription,
}
924
/// Description of a dequantization operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DequantizeOperationDescription {
    /// Quantized input tensor.
    pub input: TensorDescription,
    /// Output (dequantized) tensor.
    pub out: TensorDescription,
}
931
932impl From<ConvOptions<1>> for Conv1dOptionsDescription {
933    fn from(value: ConvOptions<1>) -> Self {
934        Self {
935            stride: value.stride,
936            padding: value.padding,
937            dilation: value.dilation,
938            groups: value.groups,
939        }
940    }
941}
942
943impl From<ConvOptions<2>> for Conv2dOptionsDescription {
944    fn from(value: ConvOptions<2>) -> Self {
945        Self {
946            stride: value.stride,
947            padding: value.padding,
948            dilation: value.dilation,
949            groups: value.groups,
950        }
951    }
952}
953
954impl From<ConvOptions<3>> for Conv3dOptionsDescription {
955    fn from(value: ConvOptions<3>) -> Self {
956        Self {
957            stride: value.stride,
958            padding: value.padding,
959            dilation: value.dilation,
960            groups: value.groups,
961        }
962    }
963}
964
965impl From<DeformConvOptions<2>> for DeformableConv2dOptionsDescription {
966    fn from(value: DeformConvOptions<2>) -> Self {
967        Self {
968            stride: value.stride,
969            padding: value.padding,
970            dilation: value.dilation,
971            weight_groups: value.weight_groups,
972            offset_groups: value.offset_groups,
973        }
974    }
975}
976
977impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsDescription {
978    fn from(value: ConvTransposeOptions<1>) -> Self {
979        Self {
980            stride: value.stride,
981            padding: value.padding,
982            padding_out: value.padding_out,
983            dilation: value.dilation,
984            groups: value.groups,
985        }
986    }
987}
988
989impl From<ConvTransposeOptions<2>> for ConvTranspose2dOptionsDescription {
990    fn from(value: ConvTransposeOptions<2>) -> Self {
991        Self {
992            stride: value.stride,
993            padding: value.padding,
994            padding_out: value.padding_out,
995            dilation: value.dilation,
996            groups: value.groups,
997        }
998    }
999}
1000
1001impl From<ConvTransposeOptions<3>> for ConvTranspose3dOptionsDescription {
1002    fn from(value: ConvTransposeOptions<3>) -> Self {
1003        Self {
1004            stride: value.stride,
1005            padding: value.padding,
1006            padding_out: value.padding_out,
1007            dilation: value.dilation,
1008            groups: value.groups,
1009        }
1010    }
1011}
1012
1013impl From<Conv1dOptionsDescription> for ConvOptions<1> {
1014    fn from(val: Conv1dOptionsDescription) -> Self {
1015        ConvOptions {
1016            stride: val.stride,
1017            padding: val.padding,
1018            dilation: val.dilation,
1019            groups: val.groups,
1020        }
1021    }
1022}
1023
1024impl From<Conv2dOptionsDescription> for ConvOptions<2> {
1025    fn from(val: Conv2dOptionsDescription) -> Self {
1026        ConvOptions {
1027            stride: val.stride,
1028            padding: val.padding,
1029            dilation: val.dilation,
1030            groups: val.groups,
1031        }
1032    }
1033}
1034
1035impl From<Conv3dOptionsDescription> for ConvOptions<3> {
1036    fn from(val: Conv3dOptionsDescription) -> Self {
1037        ConvOptions {
1038            stride: val.stride,
1039            padding: val.padding,
1040            dilation: val.dilation,
1041            groups: val.groups,
1042        }
1043    }
1044}
1045
1046impl From<DeformableConv2dOptionsDescription> for DeformConvOptions<2> {
1047    fn from(value: DeformableConv2dOptionsDescription) -> Self {
1048        DeformConvOptions {
1049            stride: value.stride,
1050            padding: value.padding,
1051            dilation: value.dilation,
1052            weight_groups: value.weight_groups,
1053            offset_groups: value.offset_groups,
1054        }
1055    }
1056}
1057
1058impl From<ConvTranspose1dOptionsDescription> for ConvTransposeOptions<1> {
1059    fn from(val: ConvTranspose1dOptionsDescription) -> Self {
1060        ConvTransposeOptions {
1061            stride: val.stride,
1062            padding: val.padding,
1063            padding_out: val.padding_out,
1064            dilation: val.dilation,
1065            groups: val.groups,
1066        }
1067    }
1068}
1069
1070impl From<ConvTranspose2dOptionsDescription> for ConvTransposeOptions<2> {
1071    fn from(val: ConvTranspose2dOptionsDescription) -> Self {
1072        ConvTransposeOptions {
1073            stride: val.stride,
1074            padding: val.padding,
1075            padding_out: val.padding_out,
1076            dilation: val.dilation,
1077            groups: val.groups,
1078        }
1079    }
1080}
1081
1082impl From<ConvTranspose3dOptionsDescription> for ConvTransposeOptions<3> {
1083    fn from(val: ConvTranspose3dOptionsDescription) -> Self {
1084        ConvTransposeOptions {
1085            stride: val.stride,
1086            padding: val.padding,
1087            padding_out: val.padding_out,
1088            dilation: val.dilation,
1089            groups: val.groups,
1090        }
1091    }
1092}
1093
/// Description of a 1D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Size of the pooling window.
    pub kernel_size: usize,
    /// Stride of the pooling window.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Whether padded positions count toward the average.
    pub count_include_pad: bool,
    /// Output tensor.
    pub out: TensorDescription,
}
1104
/// Description of a 2D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Size of the pooling window per dimension.
    pub kernel_size: [usize; 2],
    /// Stride of the pooling window per dimension.
    pub stride: [usize; 2],
    /// Padding applied to the input per dimension.
    pub padding: [usize; 2],
    /// Whether padded positions count toward the average.
    pub count_include_pad: bool,
    /// Output tensor.
    pub out: TensorDescription,
}
1115
/// Description of the backward pass of a 1D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool1dBackwardDescription {
    /// Forward input tensor.
    pub x: TensorDescription,
    /// Incoming gradient; presumably of the pooled output (TODO confirm).
    pub grad: TensorDescription,
    /// Size of the pooling window.
    pub kernel_size: usize,
    /// Stride of the pooling window.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Whether padded positions counted toward the average.
    pub count_include_pad: bool,
    /// Output tensor; presumably the gradient with respect to `x`.
    pub out: TensorDescription,
}
1127
/// Description of the backward pass of a 2D average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AvgPool2dBackwardDescription {
    /// Forward input tensor.
    pub x: TensorDescription,
    /// Incoming gradient; presumably of the pooled output (TODO confirm).
    pub grad: TensorDescription,
    /// Size of the pooling window per dimension.
    pub kernel_size: [usize; 2],
    /// Stride of the pooling window per dimension.
    pub stride: [usize; 2],
    /// Padding applied to the input per dimension.
    pub padding: [usize; 2],
    /// Whether padded positions counted toward the average.
    pub count_include_pad: bool,
    /// Output tensor; presumably the gradient with respect to `x`.
    pub out: TensorDescription,
}
1139
/// Description of a 1D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Requested output size.
    pub output_size: usize,
    /// Output tensor.
    pub out: TensorDescription,
}
1147
/// Description of a 2D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Requested output size per dimension.
    pub output_size: [usize; 2],
    /// Output tensor.
    pub out: TensorDescription,
}
1155
/// Description of the backward pass of a 1D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool1dBackwardDescription {
    /// Forward input tensor.
    pub x: TensorDescription,
    /// Incoming gradient; presumably of the pooled output (TODO confirm).
    pub grad: TensorDescription,
    /// Output tensor; presumably the gradient with respect to `x`.
    pub out: TensorDescription,
}
1163
/// Description of the backward pass of a 2D adaptive average pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct AdaptiveAvgPool2dBackwardDescription {
    /// Forward input tensor.
    pub x: TensorDescription,
    /// Incoming gradient; presumably of the pooled output (TODO confirm).
    pub grad: TensorDescription,
    /// Output tensor; presumably the gradient with respect to `x`.
    pub out: TensorDescription,
}
1171
/// Description of a 1D max pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Size of the pooling window.
    pub kernel_size: usize,
    /// Stride of the pooling window.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Dilation of the pooling window.
    pub dilation: usize,
    /// Output tensor.
    pub out: TensorDescription,
}
1182
/// Description of a 1D max pooling that also outputs the selected indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Size of the pooling window.
    pub kernel_size: usize,
    /// Stride of the pooling window.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Dilation of the pooling window.
    pub dilation: usize,
    /// Output tensor of pooled values.
    pub out: TensorDescription,
    /// Output tensor of the indices of the pooled values.
    pub out_indices: TensorDescription,
}
1194
/// Description of the backward pass of a 1D max pooling with indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool1dWithIndicesBackwardDescription {
    /// Forward input tensor.
    pub x: TensorDescription,
    /// Incoming gradient; presumably of the pooled output (TODO confirm).
    pub grad: TensorDescription,
    /// Indices produced by the forward pooling.
    pub indices: TensorDescription,
    /// Size of the pooling window.
    pub kernel_size: usize,
    /// Stride of the pooling window.
    pub stride: usize,
    /// Padding applied to the input.
    pub padding: usize,
    /// Dilation of the pooling window.
    pub dilation: usize,
    /// Output tensor; presumably the gradient with respect to `x`.
    pub out: TensorDescription,
}
1207
/// Description of a 2D max pooling.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Size of the pooling window per dimension.
    pub kernel_size: [usize; 2],
    /// Stride of the pooling window per dimension.
    pub stride: [usize; 2],
    /// Padding applied to the input per dimension.
    pub padding: [usize; 2],
    /// Dilation of the pooling window per dimension.
    pub dilation: [usize; 2],
    /// Output tensor.
    pub out: TensorDescription,
}
1218
/// Description of a 2D max pooling that also outputs the selected indices.
#[allow(missing_docs)]
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct MaxPool2dWithIndicesDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Size of the pooling window per dimension.
    pub kernel_size: [usize; 2],
    /// Stride of the pooling window per dimension.
    pub stride: [usize; 2],
    /// Padding applied to the input per dimension.
    pub padding: [usize; 2],
    /// Dilation of the pooling window per dimension.
    pub dilation: [usize; 2],
    /// Output tensor of pooled values.
    pub out: TensorDescription,
    /// Output tensor of the indices of the pooled values.
    pub out_indices: TensorDescription,
}
1230
/// Description of the backward pass of a 2D max pooling with indices.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct MaxPool2dWithIndicesBackwardDescription {
    /// Forward input tensor.
    pub x: TensorDescription,
    /// Incoming gradient; presumably of the pooled output (TODO confirm).
    pub grad: TensorDescription,
    /// Indices produced by the forward pooling.
    pub indices: TensorDescription,
    /// Size of the pooling window per dimension.
    pub kernel_size: [usize; 2],
    /// Stride of the pooling window per dimension.
    pub stride: [usize; 2],
    /// Padding applied to the input per dimension.
    pub padding: [usize; 2],
    /// Dilation of the pooling window per dimension.
    pub dilation: [usize; 2],
    /// Output tensor; presumably the gradient with respect to `x`.
    pub out: TensorDescription,
}
1243
/// Serializable form of [`InterpolateMode`]; each variant maps one-to-one onto the
/// variant of the same name.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum InterpolateModeDescription {
    /// Maps to `InterpolateMode::Nearest`.
    Nearest,
    /// Maps to `InterpolateMode::Bilinear`.
    Bilinear,
    /// Maps to `InterpolateMode::Bicubic`.
    Bicubic,
}
1251
/// Serializable form of [`InterpolateOptions`].
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateOptionsDescription {
    /// Interpolation mode.
    pub mode: InterpolateModeDescription,
}
1257
/// Description of an interpolation (resize) operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateDescription {
    /// Input tensor.
    pub x: TensorDescription,
    /// Target output size for the two resized dimensions.
    pub output_size: [usize; 2],
    /// Interpolation options.
    pub options: InterpolateOptionsDescription,
    /// Output tensor.
    pub out: TensorDescription,
}
1266
1267impl From<InterpolateModeDescription> for InterpolateMode {
1268    fn from(val: InterpolateModeDescription) -> Self {
1269        match val {
1270            InterpolateModeDescription::Nearest => Self::Nearest,
1271            InterpolateModeDescription::Bilinear => Self::Bilinear,
1272            InterpolateModeDescription::Bicubic => Self::Bicubic,
1273        }
1274    }
1275}
1276
1277impl From<InterpolateOptionsDescription> for InterpolateOptions {
1278    fn from(val: InterpolateOptionsDescription) -> Self {
1279        Self {
1280            mode: val.mode.into(),
1281        }
1282    }
1283}
1284
1285impl From<InterpolateMode> for InterpolateModeDescription {
1286    fn from(val: InterpolateMode) -> Self {
1287        match val {
1288            InterpolateMode::Nearest => Self::Nearest,
1289            InterpolateMode::Bilinear => Self::Bilinear,
1290            InterpolateMode::Bicubic => Self::Bicubic,
1291        }
1292    }
1293}
1294
1295impl From<InterpolateOptions> for InterpolateOptionsDescription {
1296    fn from(val: InterpolateOptions) -> Self {
1297        Self {
1298            mode: val.mode.into(),
1299        }
1300    }
1301}
1302
/// Description of the backward pass of an interpolation operation.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct InterpolateBackwardDescription {
    /// Forward input tensor.
    pub x: TensorDescription,
    /// Incoming gradient; presumably of the interpolated output (TODO confirm).
    pub grad: TensorDescription,
    /// Output size used by the forward interpolation.
    pub output_size: [usize; 2],
    /// Options used by the forward interpolation.
    pub options: InterpolateOptionsDescription,
    /// Output tensor; presumably the gradient with respect to `x`.
    pub out: TensorDescription,
}
1312
1313impl OperationDescription {
1314    /// Cleanup the remaining tensor handles that have not been used.
1315    pub fn nodes(&self) -> Vec<&TensorDescription> {
1316        match self {
1317            OperationDescription::BaseFloat(ops) => ops.nodes(),
1318            OperationDescription::BaseInt(ops) => ops.nodes(),
1319            OperationDescription::BaseBool(ops) => ops.nodes(),
1320            OperationDescription::NumericFloat(_dtype, ops) => ops.nodes(),
1321            OperationDescription::NumericInt(_dtype, ops) => ops.nodes(),
1322            OperationDescription::Bool(ops) => ops.nodes(),
1323            OperationDescription::Int(ops) => ops.nodes(),
1324            OperationDescription::Float(_dtype, ops) => ops.nodes(),
1325            OperationDescription::Module(ops) => ops.nodes(),
1326            OperationDescription::Custom(ops) => ops.nodes(),
1327        }
1328    }
1329}
1330
1331impl BaseOperationDescription {
1332    fn nodes(&self) -> Vec<&TensorDescription> {
1333        match self {
1334            BaseOperationDescription::ToDevice(desc) => vec![desc],
1335            BaseOperationDescription::Reshape(desc) => {
1336                vec![&desc.input, &desc.out]
1337            }
1338            BaseOperationDescription::SwapDims(desc) => {
1339                vec![&desc.input, &desc.out]
1340            }
1341            BaseOperationDescription::Permute(desc) => {
1342                vec![&desc.input, &desc.out]
1343            }
1344
1345            BaseOperationDescription::Expand(desc) => {
1346                vec![&desc.input, &desc.out]
1347            }
1348
1349            BaseOperationDescription::Flip(desc) => {
1350                vec![&desc.input, &desc.out]
1351            }
1352            BaseOperationDescription::Slice(desc) => {
1353                vec![&desc.tensor, &desc.out]
1354            }
1355            BaseOperationDescription::SliceAssign(desc) => {
1356                vec![&desc.tensor, &desc.value, &desc.out]
1357            }
1358            BaseOperationDescription::Equal(desc) => {
1359                vec![&desc.lhs, &desc.rhs, &desc.out]
1360            }
1361            BaseOperationDescription::RepeatDim(desc) => {
1362                vec![&desc.tensor, &desc.out]
1363            }
1364            BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),
1365            BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out],
1366            BaseOperationDescription::Empty(desc) => vec![desc],
1367        }
1368    }
1369}
1370
1371impl<E: Element> NumericOperationDescription<E> {
1372    fn nodes(&self) -> Vec<&TensorDescription> {
1373        match self {
1374            NumericOperationDescription::Add(desc) => {
1375                vec![&desc.lhs, &desc.rhs, &desc.out]
1376            }
1377            NumericOperationDescription::AddScalar(desc) => {
1378                vec![&desc.lhs, &desc.out]
1379            }
1380            NumericOperationDescription::Sub(desc) => {
1381                vec![&desc.lhs, &desc.rhs, &desc.out]
1382            }
1383            NumericOperationDescription::SubScalar(desc) => {
1384                vec![&desc.lhs, &desc.out]
1385            }
1386            NumericOperationDescription::Mul(desc) => {
1387                vec![&desc.lhs, &desc.rhs, &desc.out]
1388            }
1389            NumericOperationDescription::MulScalar(desc) => {
1390                vec![&desc.lhs, &desc.out]
1391            }
1392            NumericOperationDescription::Div(desc) => {
1393                vec![&desc.lhs, &desc.rhs, &desc.out]
1394            }
1395            NumericOperationDescription::DivScalar(desc) => {
1396                vec![&desc.lhs, &desc.out]
1397            }
1398            NumericOperationDescription::Rem(desc) => {
1399                vec![&desc.lhs, &desc.rhs, &desc.out]
1400            }
1401            NumericOperationDescription::RemScalar(desc) => {
1402                vec![&desc.lhs, &desc.out]
1403            }
1404            NumericOperationDescription::Ones(desc) => vec![desc],
1405            NumericOperationDescription::Gather(desc) => {
1406                vec![&desc.tensor, &desc.indices, &desc.out]
1407            }
1408            NumericOperationDescription::Scatter(desc) => {
1409                vec![&desc.tensor, &desc.indices, &desc.value, &desc.out]
1410            }
1411            NumericOperationDescription::Select(desc) => {
1412                vec![&desc.tensor, &desc.indices, &desc.out]
1413            }
1414            NumericOperationDescription::SelectAssign(desc) => {
1415                vec![&desc.tensor, &desc.indices, &desc.value, &desc.out]
1416            }
1417            NumericOperationDescription::MaskWhere(desc) => {
1418                vec![&desc.tensor, &desc.mask, &desc.value, &desc.out]
1419            }
1420            NumericOperationDescription::MaskFill(desc) => {
1421                vec![&desc.tensor, &desc.mask, &desc.out]
1422            }
1423            NumericOperationDescription::EqualElem(desc) => {
1424                vec![&desc.lhs, &desc.out]
1425            }
1426            NumericOperationDescription::GreaterElem(desc) => {
1427                vec![&desc.lhs, &desc.out]
1428            }
1429            NumericOperationDescription::GreaterEqualElem(desc) => {
1430                vec![&desc.lhs, &desc.out]
1431            }
1432            NumericOperationDescription::LowerElem(desc) => {
1433                vec![&desc.lhs, &desc.out]
1434            }
1435            NumericOperationDescription::LowerEqualElem(desc) => {
1436                vec![&desc.lhs, &desc.out]
1437            }
1438            NumericOperationDescription::Greater(desc) => {
1439                vec![&desc.lhs, &desc.rhs, &desc.out]
1440            }
1441            NumericOperationDescription::GreaterEqual(desc) => {
1442                vec![&desc.lhs, &desc.rhs, &desc.out]
1443            }
1444            NumericOperationDescription::Lower(desc) => {
1445                vec![&desc.lhs, &desc.rhs, &desc.out]
1446            }
1447            NumericOperationDescription::LowerEqual(desc) => {
1448                vec![&desc.lhs, &desc.rhs, &desc.out]
1449            }
1450            NumericOperationDescription::ArgMax(desc) => {
1451                vec![&desc.lhs, &desc.out]
1452            }
1453            NumericOperationDescription::ArgMin(desc) => {
1454                vec![&desc.lhs, &desc.out]
1455            }
1456            NumericOperationDescription::Clamp(desc) => {
1457                vec![&desc.tensor, &desc.out]
1458            }
1459            NumericOperationDescription::Abs(desc) => {
1460                vec![&desc.input, &desc.out]
1461            }
1462            NumericOperationDescription::Zeros(desc) => vec![desc],
1463            NumericOperationDescription::Full(desc) => vec![&desc.0],
1464            NumericOperationDescription::MeanDim(desc) => {
1465                vec![&desc.lhs, &desc.out]
1466            }
1467            NumericOperationDescription::Mean(desc) => {
1468                vec![&desc.input, &desc.out]
1469            }
1470            NumericOperationDescription::Sum(desc) => {
1471                vec![&desc.input, &desc.out]
1472            }
1473            NumericOperationDescription::SumDim(desc) => {
1474                vec![&desc.lhs, &desc.out]
1475            }
1476            NumericOperationDescription::Prod(desc) => {
1477                vec![&desc.input, &desc.out]
1478            }
1479            NumericOperationDescription::ProdDim(desc) => {
1480                vec![&desc.lhs, &desc.out]
1481            }
1482            NumericOperationDescription::Max(desc) => {
1483                vec![&desc.input, &desc.out]
1484            }
1485            NumericOperationDescription::MaxDimWithIndices(desc) => {
1486                vec![&desc.tensor, &desc.out_indices, &desc.out]
1487            }
1488            NumericOperationDescription::MinDimWithIndices(desc) => {
1489                vec![&desc.tensor, &desc.out_indices, &desc.out]
1490            }
1491            NumericOperationDescription::Min(desc) => {
1492                vec![&desc.input, &desc.out]
1493            }
1494            NumericOperationDescription::MaxDim(desc) => {
1495                vec![&desc.lhs, &desc.out]
1496            }
1497            NumericOperationDescription::MinDim(desc) => {
1498                vec![&desc.lhs, &desc.out]
1499            }
1500            NumericOperationDescription::IntRandom(desc) => {
1501                vec![&desc.out]
1502            }
1503            NumericOperationDescription::Powf(desc) => {
1504                vec![&desc.lhs, &desc.rhs, &desc.out]
1505            }
1506        }
1507    }
1508}
1509
1510impl FloatOperationDescription {
1511    fn nodes(&self) -> Vec<&TensorDescription> {
1512        match self {
1513            FloatOperationDescription::Matmul(desc) => {
1514                vec![&desc.lhs, &desc.rhs, &desc.out]
1515            }
1516            FloatOperationDescription::Random(desc) => vec![&desc.out],
1517            FloatOperationDescription::Exp(desc) => vec![&desc.input, &desc.out],
1518            FloatOperationDescription::Log(desc) => vec![&desc.input, &desc.out],
1519            FloatOperationDescription::Log1p(desc) => vec![&desc.input, &desc.out],
1520            FloatOperationDescription::Erf(desc) => vec![&desc.input, &desc.out],
1521            FloatOperationDescription::Recip(desc) => vec![&desc.input, &desc.out],
1522            FloatOperationDescription::PowfScalar(desc) => vec![&desc.lhs, &desc.out],
1523            FloatOperationDescription::Sqrt(desc) => vec![&desc.input, &desc.out],
1524            FloatOperationDescription::Cos(desc) => vec![&desc.input, &desc.out],
1525            FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out],
1526            FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out],
1527            FloatOperationDescription::Round(desc) => vec![&desc.input, &desc.out],
1528            FloatOperationDescription::Floor(desc) => vec![&desc.input, &desc.out],
1529            FloatOperationDescription::Ceil(desc) => vec![&desc.input, &desc.out],
1530            FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
1531            FloatOperationDescription::Quantize(desc) => {
1532                if let Some(offset) = &desc.qparams.offset {
1533                    vec![&desc.tensor, &desc.qparams.scale, &offset, &desc.out]
1534                } else {
1535                    vec![&desc.tensor, &desc.qparams.scale, &desc.out]
1536                }
1537            }
1538            FloatOperationDescription::Dequantize(desc) => vec![&desc.input, &desc.out],
1539        }
1540    }
1541}
1542
1543impl IntOperationDescription {
1544    fn nodes(&self) -> Vec<&TensorDescription> {
1545        match self {
1546            IntOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out],
1547        }
1548    }
1549}
1550
1551impl BoolOperationDescription {
1552    fn nodes(&self) -> Vec<&TensorDescription> {
1553        match self {
1554            BoolOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out],
1555            BoolOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
1556            BoolOperationDescription::Not(desc) => vec![&desc.input, &desc.out],
1557        }
1558    }
1559}
1560
1561impl ModuleOperationDescription {
1562    fn nodes(&self) -> Vec<&TensorDescription> {
1563        match self {
1564            ModuleOperationDescription::Embedding(desc) => {
1565                vec![&desc.weights, &desc.indices, &desc.out]
1566            }
1567            ModuleOperationDescription::EmbeddingBackward(desc) => {
1568                vec![&desc.weights, &desc.out_grad, &desc.indices, &desc.out]
1569            }
1570            ModuleOperationDescription::Conv1d(desc) => {
1571                if let Some(bias) = &desc.bias {
1572                    vec![&desc.x, &desc.weight, &bias, &desc.out]
1573                } else {
1574                    vec![&desc.x, &desc.weight, &desc.out]
1575                }
1576            }
1577            ModuleOperationDescription::Conv2d(desc) => {
1578                if let Some(bias) = &desc.bias {
1579                    vec![&desc.x, &desc.weight, &bias, &desc.out]
1580                } else {
1581                    vec![&desc.x, &desc.weight, &desc.out]
1582                }
1583            }
1584            ModuleOperationDescription::Conv3d(desc) => {
1585                if let Some(bias) = &desc.bias {
1586                    vec![&desc.x, &desc.weight, &bias, &desc.out]
1587                } else {
1588                    vec![&desc.x, &desc.weight, &desc.out]
1589                }
1590            }
1591            ModuleOperationDescription::DeformableConv2d(desc) => match (&desc.mask, &desc.bias) {
1592                (Some(mask), Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias],
1593                (Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
1594                (None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
1595                (None, None) => vec![&desc.x, &desc.offset, &desc.weight],
1596            },
1597            ModuleOperationDescription::DeformableConv2dBackward(desc) => {
1598                match (&desc.mask, &desc.bias) {
1599                    (Some(mask), Some(bias)) => {
1600                        vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias]
1601                    }
1602                    (Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
1603                    (None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
1604                    (None, None) => vec![&desc.x, &desc.offset, &desc.weight],
1605                }
1606            }
1607            ModuleOperationDescription::ConvTranspose1d(desc) => {
1608                if let Some(bias) = &desc.bias {
1609                    vec![&desc.x, &desc.weight, &bias, &desc.out]
1610                } else {
1611                    vec![&desc.x, &desc.weight, &desc.out]
1612                }
1613            }
1614            ModuleOperationDescription::ConvTranspose2d(desc) => {
1615                if let Some(bias) = &desc.bias {
1616                    vec![&desc.x, &desc.weight, &bias, &desc.out]
1617                } else {
1618                    vec![&desc.x, &desc.weight, &desc.out]
1619                }
1620            }
1621            ModuleOperationDescription::ConvTranspose3d(desc) => {
1622                if let Some(bias) = &desc.bias {
1623                    vec![&desc.x, &desc.weight, &bias, &desc.out]
1624                } else {
1625                    vec![&desc.x, &desc.weight, &desc.out]
1626                }
1627            }
1628            ModuleOperationDescription::AvgPool1d(desc) => {
1629                vec![&desc.x, &desc.out]
1630            }
1631            ModuleOperationDescription::AvgPool2d(desc) => {
1632                vec![&desc.x, &desc.out]
1633            }
1634            ModuleOperationDescription::AvgPool1dBackward(desc) => {
1635                vec![&desc.x, &desc.out, &desc.grad]
1636            }
1637            ModuleOperationDescription::AvgPool2dBackward(desc) => {
1638                vec![&desc.x, &desc.out, &desc.grad]
1639            }
1640            ModuleOperationDescription::AdaptiveAvgPool1d(desc) => {
1641                vec![&desc.x, &desc.out]
1642            }
1643            ModuleOperationDescription::AdaptiveAvgPool2d(desc) => {
1644                vec![&desc.x, &desc.out]
1645            }
1646            ModuleOperationDescription::AdaptiveAvgPool1dBackward(desc) => {
1647                vec![&desc.x, &desc.out, &desc.grad]
1648            }
1649            ModuleOperationDescription::AdaptiveAvgPool2dBackward(desc) => {
1650                vec![&desc.x, &desc.out, &desc.grad]
1651            }
1652            ModuleOperationDescription::MaxPool1d(desc) => {
1653                vec![&desc.x, &desc.out]
1654            }
1655            ModuleOperationDescription::MaxPool1dWithIndices(desc) => {
1656                vec![&desc.x, &desc.out, &desc.out_indices]
1657            }
1658            ModuleOperationDescription::MaxPool1dWithIndicesBackward(desc) => {
1659                vec![&desc.x, &desc.out, &desc.indices, &desc.grad]
1660            }
1661            ModuleOperationDescription::MaxPool2d(desc) => {
1662                vec![&desc.x, &desc.out]
1663            }
1664            ModuleOperationDescription::MaxPool2dWithIndices(desc) => {
1665                vec![&desc.x, &desc.out, &desc.out_indices]
1666            }
1667            ModuleOperationDescription::MaxPool2dWithIndicesBackward(desc) => {
1668                vec![&desc.x, &desc.out, &desc.indices, &desc.grad]
1669            }
1670            ModuleOperationDescription::Interpolate(desc) => {
1671                vec![&desc.x, &desc.out]
1672            }
1673            ModuleOperationDescription::InterpolateBackward(desc) => {
1674                vec![&desc.x, &desc.out, &desc.grad]
1675            }
1676        }
1677    }
1678}
1679
1680impl core::hash::Hash for RandomOperationDescription {
1681    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1682        self.out.hash(state);
1683
1684        match self.distribution {
1685            Distribution::Default => 1u8.hash(state),
1686            Distribution::Bernoulli(_) => 2u8.hash(state),
1687            Distribution::Uniform(_, _) => 3u8.hash(state),
1688            Distribution::Normal(_, _) => 4u8.hash(state),
1689        }
1690    }
1691}
1692
1693impl<E> core::hash::Hash for ScalarOperationDescription<E> {
1694    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1695        self.lhs.hash(state);
1696        self.out.hash(state);
1697    }
1698}
1699
1700impl<E> core::hash::Hash for MaskFillOperationDescription<E> {
1701    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1702        self.tensor.hash(state);
1703        self.mask.hash(state);
1704        self.out.hash(state);
1705    }
1706}
1707
1708impl<E> core::hash::Hash for ClampOperationDescription<E> {
1709    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1710        self.tensor.hash(state);
1711        self.out.hash(state);
1712    }
1713}
1714
1715impl<E> core::hash::Hash for NumericOperationDescription<E> {
1716    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
1717        match self {
1718            NumericOperationDescription::Add(desc) => desc.hash(state),
1719            NumericOperationDescription::AddScalar(desc) => desc.hash(state),
1720            NumericOperationDescription::Sub(desc) => desc.hash(state),
1721            NumericOperationDescription::SubScalar(desc) => desc.hash(state),
1722            NumericOperationDescription::Div(desc) => desc.hash(state),
1723            NumericOperationDescription::DivScalar(desc) => desc.hash(state),
1724            NumericOperationDescription::Rem(desc) => desc.hash(state),
1725            NumericOperationDescription::RemScalar(desc) => desc.hash(state),
1726            NumericOperationDescription::Mul(desc) => desc.hash(state),
1727            NumericOperationDescription::MulScalar(desc) => desc.hash(state),
1728            NumericOperationDescription::Abs(desc) => desc.hash(state),
1729            NumericOperationDescription::Ones(desc) => desc.hash(state),
1730            NumericOperationDescription::Zeros(desc) => desc.hash(state),
1731            NumericOperationDescription::Full(desc) => desc.0.hash(state),
1732            NumericOperationDescription::Gather(desc) => desc.hash(state),
1733            NumericOperationDescription::Scatter(desc) => desc.hash(state),
1734            NumericOperationDescription::Select(desc) => desc.hash(state),
1735            NumericOperationDescription::SelectAssign(desc) => desc.hash(state),
1736            NumericOperationDescription::MaskWhere(desc) => desc.hash(state),
1737            NumericOperationDescription::MaskFill(desc) => desc.hash(state),
1738            NumericOperationDescription::MeanDim(desc) => desc.hash(state),
1739            NumericOperationDescription::Mean(desc) => desc.hash(state),
1740            NumericOperationDescription::Sum(desc) => desc.hash(state),
1741            NumericOperationDescription::SumDim(desc) => desc.hash(state),
1742            NumericOperationDescription::Prod(desc) => desc.hash(state),
1743            NumericOperationDescription::ProdDim(desc) => desc.hash(state),
1744            NumericOperationDescription::EqualElem(desc) => desc.hash(state),
1745            NumericOperationDescription::Greater(desc) => desc.hash(state),
1746            NumericOperationDescription::GreaterElem(desc) => desc.hash(state),
1747            NumericOperationDescription::GreaterEqual(desc) => desc.hash(state),
1748            NumericOperationDescription::GreaterEqualElem(desc) => desc.hash(state),
1749            NumericOperationDescription::Lower(desc) => desc.hash(state),
1750            NumericOperationDescription::LowerElem(desc) => desc.hash(state),
1751            NumericOperationDescription::LowerEqual(desc) => desc.hash(state),
1752            NumericOperationDescription::LowerEqualElem(desc) => desc.hash(state),
1753            NumericOperationDescription::ArgMax(desc) => desc.hash(state),
1754            NumericOperationDescription::ArgMin(desc) => desc.hash(state),
1755            NumericOperationDescription::Max(desc) => desc.hash(state),
1756            NumericOperationDescription::MaxDimWithIndices(desc) => desc.hash(state),
1757            NumericOperationDescription::MinDimWithIndices(desc) => desc.hash(state),
1758            NumericOperationDescription::Min(desc) => desc.hash(state),
1759            NumericOperationDescription::MaxDim(desc) => desc.hash(state),
1760            NumericOperationDescription::MinDim(desc) => desc.hash(state),
1761            NumericOperationDescription::Clamp(desc) => desc.hash(state),
1762            NumericOperationDescription::IntRandom(desc) => desc.hash(state),
1763            NumericOperationDescription::Powf(desc) => desc.hash(state),
1764        }
1765    }
1766}