//! Builders for operation IR (`builder.rs`): `create` constructors that
//! derive each op's output tensor shape and dtype from its inputs.

1#![allow(missing_docs)]
2
3use alloc::vec::Vec;
4use burn_backend::{
5    DType, Distribution, Shape, Slice, calculate_matmul_output,
6    ops::{
7        conv::{
8            calculate_conv_output_shape, calculate_conv_transpose_output_shape,
9            calculate_pool_output_shape,
10        },
11        unfold::calculate_unfold_shape,
12    },
13    quantization::QuantScheme,
14    tensor::IndexingUpdateOp,
15};
16
17use crate::{ScalarIr, TensorId, TensorIr};
18
19use super::operation::*;
20
21impl CreationOpIr {
22    pub fn create(shape: Shape, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {
23        let out = TensorIr::uninit(new_id(), shape, dtype);
24
25        CreationOpIr { out }
26    }
27}
28
29impl InitOperationIr {
30    pub fn create(shape: Shape, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {
31        let out = TensorIr::uninit(new_id(), shape, dtype);
32
33        InitOperationIr { out }
34    }
35}
36
37impl RandomOpIr {
38    pub fn create(
39        shape: Shape,
40        dtype: DType,
41        distribution: Distribution,
42        new_id: impl FnOnce() -> TensorId,
43    ) -> Self {
44        let out = TensorIr::uninit(new_id(), shape, dtype);
45
46        RandomOpIr { out, distribution }
47    }
48}
49
50impl FullOpIr {
51    pub fn create(
52        shape: Shape,
53        dtype: DType,
54        value: ScalarIr,
55        new_id: impl FnOnce() -> TensorId,
56    ) -> Self {
57        let out = TensorIr::uninit(new_id(), shape, dtype);
58
59        FullOpIr { out, value }
60    }
61}
62
63impl CastOpIr {
64    pub fn create(input: TensorIr, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {
65        let out = TensorIr::uninit(new_id(), input.shape.clone(), dtype);
66        CastOpIr { input, out }
67    }
68}
69
70impl ShapeOpIr {
71    pub fn expand(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {
72        let shape = input.shape.expand(shape).unwrap();
73        Self::create(input, shape, new_id)
74    }
75
76    pub fn reshape(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {
77        let shape = input.shape.reshape(shape).unwrap();
78        Self::create(input, shape, new_id)
79    }
80
81    fn create(input: TensorIr, shape: Shape, new_id: impl FnOnce() -> TensorId) -> Self {
82        let out = TensorIr::uninit(new_id(), shape, input.dtype);
83        ShapeOpIr { input, out }
84    }
85}
86
87// "Lower" specific operations into a binary or unary op representation.
88// Useful when collecting inputs and outputs and don't care about the other semantics.
89impl From<MatmulOpIr> for BinaryOpIr {
90    fn from(value: MatmulOpIr) -> Self {
91        Self {
92            lhs: value.lhs,
93            rhs: value.rhs,
94            out: value.out,
95        }
96    }
97}
98
99impl From<ReduceOpIr> for UnaryOpIr {
100    fn from(value: ReduceOpIr) -> Self {
101        Self {
102            input: value.input,
103            out: value.out,
104        }
105    }
106}
107
/// Errors that can occur while building operation IR.
#[derive(Debug)]
#[allow(missing_docs)]
pub enum IrError {
    /// The input dtypes are not compatible with a single output dtype.
    DTypeMismatch,
}
113
114fn dtype_compat(lhs: &DType, rhs: &DType) -> bool {
115    let lhs_qfloat = matches!(lhs, DType::QFloat(_));
116    let rhs_qfloat = matches!(rhs, DType::QFloat(_));
117    if lhs_qfloat && (rhs_qfloat || rhs.is_float())
118        || lhs.is_float() && (rhs_qfloat || rhs.is_float())
119    {
120        true
121    } else {
122        lhs == rhs
123    }
124}
125
126fn output_check<'a, I>(inputs: I, compat: impl Fn(&DType, &DType) -> bool) -> Result<DType, IrError>
127where
128    I: IntoIterator<Item = &'a DType>,
129{
130    let mut iter = inputs.into_iter();
131    let first = iter.next().unwrap();
132    for d in iter {
133        if !compat(first, d) {
134            return Err(IrError::DTypeMismatch);
135        }
136    }
137    Ok(*first)
138}
139
140fn output_dtype<'a, I: IntoIterator<Item = &'a DType>>(inputs: I) -> Result<DType, IrError> {
141    output_check(inputs, |a, b| a == b)
142}
143
144fn output_dtype_mixed<'a, I: IntoIterator<Item = &'a DType>>(inputs: I) -> Result<DType, IrError> {
145    output_check(inputs, dtype_compat)
146}
147
/// Macro to implement `create` constructors for operations with a single output.
///
/// Supports shape and dtype validation.
macro_rules! impl_ir_create {
    // Internal rule (`@create_fn`): emit the standard `create` fn for `$op`.
    // `$shape`/`$dtype` are expressions evaluated inside the fn body, so they
    // may reference the declared input fields by name.
    (@create_fn $op:ident { $( $field:ident : $ty:ty ),* $(,)? } , $shape:expr, $dtype:expr) => {
        #[doc = "Create a new operation IR from the given inputs."]
        #[doc = "`new_id` should generate a unique `TensorId` for the uninitialized output tensor."]
        #[allow(clippy::too_many_arguments)]
        pub fn create($( $field : $ty ),*, new_id: impl FnOnce() -> crate::TensorId) -> $op {
            let shape = $shape;
            let dtype = $dtype;
            let out = TensorIr::uninit(new_id(), shape, dtype);
            $op { $( $field ),*, out }
        }
    };

    // Case: simple op, single `create`
    (
        $op:ident { $( $field:ident : $ty:ty ),* $(,)? },
        shape = $shape:expr,
        dtype = $dtype:expr
    ) => {
        impl $op {
            impl_ir_create!(@create_fn $op { $( $field : $ty ),* }, $shape, $dtype);
        }
    };

    // Case: op with one additional constructor that accepts an explicit output dtype
    (
        $op:ident { $( $field:ident : $ty:ty ),* $(,)? },
        shape = $shape:expr,
        dtype = $dtype:expr,
        $fn_name:ident ( $extra:ident : $extra_ty:ty )
    ) => {
        impl $op {
            impl_ir_create!(@create_fn $op { $( $field : $ty ),* }, $shape, $dtype);

            #[doc = "Create a new operation IR from the given inputs and the given output dtype."]
            #[allow(clippy::too_many_arguments)]
            pub fn $fn_name($( $field : $ty ),*, $extra: $extra_ty, new_id: impl FnOnce() -> crate::TensorId) -> Self {
                let shape = $shape;
                // Evaluating `$dtype` preserves its validation side effects
                // (e.g. dtype-compat `unwrap`s) even though `$extra` is the
                // dtype actually used for the output.
                let _ = $dtype; // still validates dtype if needed
                let out = TensorIr::uninit(new_id(), shape, $extra);
                $op { $( $field ),*, out }
            }
        }
    };
}
196
// --- Element-wise and matmul ops ---

impl_ir_create!(
    UnaryOpIr { input: TensorIr },
    shape = input.shape.clone(),
    dtype = input.dtype,
    // Additional constructor for unary comparisons
    create_comparison(bool_dtype: DType)
);

impl_ir_create!(
    BinaryOpIr {
        lhs: TensorIr,
        rhs: TensorIr
    },
    // Binary ops broadcast the operand shapes; dtypes must match exactly.
    shape = lhs.shape.broadcast(&rhs.shape).unwrap(),
    dtype = output_dtype([&lhs.dtype, &rhs.dtype]).unwrap(),
    // Additional constructor for binary comparisons
    create_comparison(bool_dtype: DType)
);

impl_ir_create!(
    ScalarOpIr {
        lhs: TensorIr,
        rhs: ScalarIr
    },
    shape = lhs.shape.clone(),
    dtype = lhs.dtype,
    // Additional constructor for scalar comparisons
    create_comparison(bool_dtype: DType)
);

impl_ir_create!(
    MatmulOpIr {
        lhs: TensorIr,
        rhs: TensorIr
    },
    shape = calculate_matmul_output(&lhs.shape, &rhs.shape).unwrap(),
    // Matmul operands may mix float-like dtypes (see `dtype_compat`).
    dtype = output_dtype_mixed([&lhs.dtype, &rhs.dtype]).unwrap(),
    // Additional constructor for mixed dtypes
    create_mixed(out_dtype: DType)
);
237
// --- Shape-manipulation ops: shapes are derived via the `Shape` helpers ---

impl_ir_create!(
    SwapDimsOpIr {
        input: TensorIr,
        dim1: usize,
        dim2: usize
    },
    shape = input.shape.clone().swap(dim1, dim2).unwrap(),
    dtype = input.dtype
);

impl_ir_create!(
    PermuteOpIr { input: TensorIr, axes: Vec<usize> },
    shape = input.shape.clone().permute(&axes).unwrap(),
    dtype = input.dtype
);

impl_ir_create!(
    RepeatDimOpIr {
        tensor: TensorIr,
        dim: usize,
        times: usize
    },
    shape = tensor.shape.clone().repeat(dim, times).unwrap(),
    dtype = tensor.dtype
);

impl_ir_create!(
    // Flipping axes never changes the shape.
    FlipOpIr { input: TensorIr, axes: Vec<usize> },
    shape = input.shape.clone(), // TODO: check if axes are within the tensor dimensions
    dtype = input.dtype
);

impl_ir_create!(
    CatOpIr { tensors: Vec<TensorIr>, dim: usize },
    shape = Shape::cat(tensors.iter().map(|t| &t.shape), dim).unwrap(),
    // All concatenated tensors must share the same dtype.
    dtype = output_dtype(tensors.iter().map(|t| &t.dtype)).unwrap()
);
275
// --- Indexing / reduction ops ---

impl_ir_create!(
    GatherOpIr {
        tensor: TensorIr,
        dim: usize,
        indices: TensorIr
    },
    // Gather output takes the indices' shape and the source tensor's dtype.
    shape = indices.shape.clone(), // TODO: check dims compat between tensor and indices
    dtype = tensor.dtype
);

impl_ir_create!(
    ScatterOpIr {
        tensor: TensorIr,
        dim: usize,
        indices: TensorIr,
        value: TensorIr,
        update: IndexingUpdateOp
    },
    shape = tensor.shape.clone(), // TODO: check dims compat between tensor and indices
    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()
);

impl_ir_create!(
    // Full reduction collapses to a single-element tensor.
    ReduceOpIr { input: TensorIr },
    shape = [1].into(),
    dtype = input.dtype
);
303
impl_ir_create!(
    ReduceDimOpIr {
        input: TensorIr,
        axis: usize
    },
    shape = input.shape.clone().reduce(axis).unwrap(),
    dtype = input.dtype,
    // Additional constructor for argument reduction
    create_arg(ind_dtype: DType)
);

impl_ir_create!(
    DimOpIr {
        input: TensorIr,
        axis: usize
    },
    shape = input.shape.clone(), // TODO: check dims within rank
    dtype = input.dtype
);

impl_ir_create!(
    SelectOpIr {
        tensor: TensorIr,
        dim: usize,
        indices: TensorIr
    },
    // TODO: shape.select?
    // The selected dim is resized to the number of indices.
    shape = {
        let mut s = tensor.shape.clone();
        s[dim] = indices.shape[0];
        s
    },
    dtype = tensor.dtype
);

impl_ir_create!(
    SelectAssignOpIr {
        tensor: TensorIr,
        dim: usize,
        indices: TensorIr,
        value: TensorIr,
        update: IndexingUpdateOp
    },
    // TODO: check value and indices shape match for dim
    shape = tensor.shape.clone(),
    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()
);
351
// --- Slicing and masking ops ---

impl_ir_create!(
    SliceOpIr {
        tensor: TensorIr,
        ranges: Vec<Slice>,
    },
    shape = tensor.shape.clone().slice(&ranges).unwrap(),
    dtype = tensor.dtype
);

impl_ir_create!(
    SliceAssignOpIr {
        tensor: TensorIr,
        ranges: Vec<Slice>,
        value: TensorIr
    },
    // TODO: check slice and value number of elements match
    shape = tensor.shape.clone(),
    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()
);

impl_ir_create!(
    MaskWhereOpIr {
        tensor: TensorIr,
        mask: TensorIr,
        value: TensorIr
    },
    // All three operand shapes are broadcast together.
    shape = Shape::broadcast_many([&tensor.shape, &mask.shape, &value.shape]).unwrap(),
    dtype = output_dtype([&tensor.dtype, &value.dtype]).unwrap()
);

impl_ir_create!(
    MaskFillOpIr {
        tensor: TensorIr,
        mask: TensorIr,
        value: ScalarIr
    },
    shape = tensor.shape.broadcast(&mask.shape).unwrap(),
    dtype = tensor.dtype
);

impl_ir_create!(
    ClampOpIr {
        tensor: TensorIr,
        min: ScalarIr,
        max: ScalarIr
    },
    shape = tensor.shape.clone(),
    dtype = tensor.dtype
);
401
// --- Pooling ops. Forward shapes come from `calculate_pool_output_shape`;
// --- backward ops produce a gradient with the forward input's shape.

impl_ir_create!(
    AvgPool1dOpIr {
        x: TensorIr,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool
    },
    shape = calculate_pool_output_shape(
        &x.shape,
        &[kernel_size],
        &[stride],
        &[padding],
        // Avg pool has no dilation; pass the identity.
        &[1],
        ceil_mode
    )
    .unwrap(),
    dtype = x.dtype
);

impl_ir_create!(
    AvgPool1dBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

impl_ir_create!(
    AvgPool2dOpIr {
        x: TensorIr,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool
    },
    shape = calculate_pool_output_shape(
        &x.shape,
        &kernel_size,
        &stride,
        &padding,
        // Avg pool has no dilation; pass the identity.
        &[1, 1],
        ceil_mode
    )
    .unwrap(),
    dtype = x.dtype
);

impl_ir_create!(
    AvgPool2dBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

impl_ir_create!(
    MaxPool1dOpIr {
        x: TensorIr,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool
    },
    shape = calculate_pool_output_shape(
        &x.shape,
        &[kernel_size],
        &[stride],
        &[padding],
        &[dilation],
        ceil_mode
    )
    .unwrap(),
    dtype = x.dtype
);

impl_ir_create!(
    MaxPool2dOpIr {
        x: TensorIr,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool
    },
    shape = calculate_pool_output_shape(
        &x.shape,
        &kernel_size,
        &stride,
        &padding,
        &dilation,
        ceil_mode
    )
    .unwrap(),
    dtype = x.dtype
);

impl_ir_create!(
    MaxPool1dWithIndicesBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        indices: TensorIr,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

impl_ir_create!(
    MaxPool2dWithIndicesBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        indices: TensorIr,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);
543
// --- Adaptive pooling, interpolation, and grid sampling ---

impl_ir_create!(
    AdaptiveAvgPool1dOpIr {
        x: TensorIr,
        output_size: usize
    },
    // [N, C, L] -> [N, C, output_size]
    shape = Shape::new([x.shape[0], x.shape[1], output_size]),
    dtype = x.dtype
);

impl_ir_create!(
    AdaptiveAvgPool2dOpIr {
        x: TensorIr,
        output_size: [usize; 2]
    },
    // [N, C, H, W] -> [N, C, output_size.0, output_size.1]
    shape = Shape::new([x.shape[0], x.shape[1], output_size[0], output_size[1]]),
    dtype = x.dtype
);

impl_ir_create!(
    AdaptiveAvgPool1dBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

impl_ir_create!(
    AdaptiveAvgPool2dBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

impl_ir_create!(
    InterpolateOpIr {
        x: TensorIr,
        output_size: [usize; 2],
        options: InterpolateOptionsIr
    },
    shape = Shape::new([x.shape[0], x.shape[1], output_size[0], output_size[1]]),
    dtype = x.dtype
);

impl_ir_create!(
    InterpolateBackwardOpIr {
        x: TensorIr,
        grad: TensorIr,
        output_size: [usize; 2],
        options: InterpolateOptionsIr
    },
    shape = x.shape.clone(),
    dtype = x.dtype
);

impl_ir_create!(
    GridSample2dOpIr {
        tensor: TensorIr,
        grid: TensorIr,
        options: GridSampleOptionsIr
    },
    // Input tensor: [N, C, H_in, W_in]
    // Grid: [N, H_out, W_out, 2]
    // Output: [N, C, H_out, W_out]
    shape = Shape::new([
        tensor.shape[0],
        tensor.shape[1],
        grid.shape[1],
        grid.shape[2]
    ]),
    dtype = tensor.dtype
);
618
// --- Convolution ops. Output shapes come from the conv shape helpers; the
// --- output dtype requires all present inputs (x, weight, optional bias,
// --- and, for deformable conv, offset/mask) to share the same dtype.

impl_ir_create!(
    Conv1dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: Conv1dOptionsIr
    },
    shape = calculate_conv_output_shape(
            &x.shape,
            &weight.shape,
            &options.stride,
            &options.padding,
            &options.dilation,
        )
        .unwrap(),
    dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&weight.dtype),
                // The bias dtype only participates when a bias is present.
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap()
);

impl_ir_create!(
    Conv2dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: Conv2dOptionsIr
    },
    shape = calculate_conv_output_shape(
            &x.shape,
            &weight.shape,
            &options.stride,
            &options.padding,
            &options.dilation,
        )
        .unwrap(),
    dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&weight.dtype),
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap()
);

impl_ir_create!(
    Conv3dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: Conv3dOptionsIr
    },
    shape = calculate_conv_output_shape(
            &x.shape,
            &weight.shape,
            &options.stride,
            &options.padding,
            &options.dilation,
        )
        .unwrap(),
    dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&weight.dtype),
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap()
);

impl_ir_create!(
    DeformConv2dOpIr {
        x: TensorIr,
        offset: TensorIr,
        weight: TensorIr,
        mask: Option<TensorIr>,
        bias: Option<TensorIr>,
        options: DeformableConv2dOptionsIr
    },
    shape = calculate_conv_output_shape(
            &x.shape,
            &weight.shape,
            &options.stride,
            &options.padding,
            &options.dilation,
        )
        .unwrap(),
    dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&offset.dtype),
                Some(&weight.dtype),
                mask.as_ref().map(|m| &m.dtype),
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap()
);

impl_ir_create!(
    ConvTranspose1dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: ConvTranspose1dOptionsIr
    },
    shape = calculate_conv_transpose_output_shape(
            &x.shape,
            &weight.shape,
            &options.stride,
            &options.padding,
            &options.padding_out,
            &options.dilation,
            options.groups,
        )
        .unwrap(),
    dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&weight.dtype),
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap()
);

impl_ir_create!(
    ConvTranspose2dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: ConvTranspose2dOptionsIr
    },
    shape = calculate_conv_transpose_output_shape(
            &x.shape,
            &weight.shape,
            &options.stride,
            &options.padding,
            &options.padding_out,
            &options.dilation,
            options.groups,
        )
        .unwrap(),
    dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&weight.dtype),
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap()
);

impl_ir_create!(
    ConvTranspose3dOpIr {
        x: TensorIr,
        weight: TensorIr,
        bias: Option<TensorIr>,
        options: ConvTranspose3dOptionsIr
    },
    shape = calculate_conv_transpose_output_shape(
            &x.shape,
            &weight.shape,
            &options.stride,
            &options.padding,
            &options.padding_out,
            &options.dilation,
            options.groups,
        )
        .unwrap(),
    dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&weight.dtype),
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap()
);

impl_ir_create!(
    UnfoldOpIr {
        input: TensorIr,
        dim: usize,
        size: usize,
        step: usize
    },
    shape = calculate_unfold_shape(input.shape.clone(), dim, size, step),
    dtype = input.dtype
);

impl_ir_create!(
    CrossOpIr {
        lhs: TensorIr,
        rhs: TensorIr,
        dim: usize
    },
    shape = lhs.shape.broadcast(&rhs.shape).unwrap(),
    dtype = output_dtype([&lhs.dtype, &rhs.dtype]).unwrap()
);

impl_ir_create!(
    QuantizeOpIr {
        tensor: TensorIr,
        qparams: QuantizationParametersIr,
        scheme: QuantScheme
    },
    shape = tensor.shape.clone(),
    // Quantization keeps the shape but re-types the output as QFloat.
    dtype = DType::QFloat(scheme)
);
848
849impl DequantizeOpIr {
850    pub fn create(input: TensorIr, dtype: DType, new_id: impl FnOnce() -> TensorId) -> Self {
851        let out = TensorIr::uninit(new_id(), input.shape.clone(), dtype);
852
853        DequantizeOpIr { input, out }
854    }
855}
856
857// Operations with multiple outputs
858
859impl ReduceDimWithIndicesOpIr {
860    pub fn create(
861        tensor: TensorIr,
862        dim: usize,
863        dtype_indices: DType,
864        mut new_id: impl FnMut() -> TensorId,
865    ) -> Self {
866        let mut shape = tensor.shape.clone();
867        shape[dim] = 1;
868        let out = TensorIr::uninit(new_id(), shape.clone(), tensor.dtype);
869        let out_indices = TensorIr::uninit(new_id(), shape.clone(), dtype_indices);
870
871        ReduceDimWithIndicesOpIr {
872            tensor,
873            dim,
874            out,
875            out_indices,
876        }
877    }
878}
879
impl DeformConv2dBackwardOpIr {
    /// Build the backward op IR for a deformable 2d convolution.
    ///
    /// Produces one gradient tensor per forward input (input, offset,
    /// weight, and mask/bias when present), each with the matching input's
    /// shape. `new_id` is called once per gradient, in that order.
    #[allow(clippy::too_many_arguments)]
    pub fn create(
        x: TensorIr,
        offset: TensorIr,
        weight: TensorIr,
        mask: Option<TensorIr>,
        bias: Option<TensorIr>,
        out_grad: TensorIr,
        options: DeformableConv2dOptionsIr,
        mut new_id: impl FnMut() -> TensorId,
    ) -> Self {
        // All gradients share one dtype, validated across the inputs.
        // NOTE(review): `offset.dtype` is not part of this check, unlike the
        // forward `DeformConv2dOpIr` — confirm the asymmetry is intended.
        let dtype = output_dtype(
            [
                Some(&x.dtype),
                Some(&weight.dtype),
                mask.as_ref().map(|m| &m.dtype),
                bias.as_ref().map(|b| &b.dtype),
            ]
            .iter()
            .filter_map(|&d| d),
        )
        .unwrap();

        let input_grad = TensorIr::uninit(new_id(), x.shape.clone(), dtype);
        let offset_grad = TensorIr::uninit(new_id(), offset.shape.clone(), dtype);
        let weight_grad = TensorIr::uninit(new_id(), weight.shape.clone(), dtype);
        // Optional inputs only get a gradient tensor when present.
        let mask_grad = mask
            .as_ref()
            .map(|t| TensorIr::uninit(new_id(), t.shape.clone(), dtype));
        let bias_grad = bias
            .as_ref()
            .map(|t| TensorIr::uninit(new_id(), t.shape.clone(), dtype));

        DeformConv2dBackwardOpIr {
            x,
            offset,
            weight,
            mask,
            bias,
            out_grad,
            options,
            input_grad,
            offset_grad,
            weight_grad,
            mask_grad,
            bias_grad,
        }
    }
}
930
931impl MaxPool1dWithIndicesOpIr {
932    #[allow(clippy::too_many_arguments)]
933    pub fn create(
934        x: TensorIr,
935        kernel_size: usize,
936        stride: usize,
937        padding: usize,
938        dilation: usize,
939        ceil_mode: bool,
940        dtype_indices: DType,
941        mut new_id: impl FnMut() -> TensorId,
942    ) -> Self {
943        let shape = calculate_pool_output_shape(
944            &x.shape,
945            &[kernel_size],
946            &[stride],
947            &[padding],
948            &[dilation],
949            ceil_mode,
950        )
951        .unwrap();
952        let out = TensorIr::uninit(new_id(), shape.clone(), x.dtype);
953        let out_indices = TensorIr::uninit(new_id(), shape, dtype_indices);
954
955        MaxPool1dWithIndicesOpIr {
956            x,
957            kernel_size,
958            stride,
959            padding,
960            dilation,
961            ceil_mode,
962            out,
963            out_indices,
964        }
965    }
966}
967
968impl MaxPool2dWithIndicesOpIr {
969    #[allow(clippy::too_many_arguments)]
970    pub fn create(
971        x: TensorIr,
972        kernel_size: [usize; 2],
973        stride: [usize; 2],
974        padding: [usize; 2],
975        dilation: [usize; 2],
976        ceil_mode: bool,
977        dtype_indices: DType,
978        mut new_id: impl FnMut() -> TensorId,
979    ) -> Self {
980        let shape = calculate_pool_output_shape(
981            &x.shape,
982            &kernel_size,
983            &stride,
984            &padding,
985            &dilation,
986            ceil_mode,
987        )
988        .unwrap();
989        let out = TensorIr::uninit(new_id(), shape.clone(), x.dtype);
990        let out_indices = TensorIr::uninit(new_id(), shape, dtype_indices);
991
992        MaxPool2dWithIndicesOpIr {
993            x,
994            kernel_size,
995            stride,
996            padding,
997            dilation,
998            ceil_mode,
999            out,
1000            out_indices,
1001        }
1002    }
1003}