Skip to main content

rlx_ir/ops/
conv2d.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4//! NCHW convolution builders (`conv2d`, `conv_transpose2d`).
5
6use crate::{Graph, NodeId, Op};
7
8impl Graph {
9    /// 2D convolution on NCHW tensors (`Op::Conv`). Weight `[C_out, C_in/g, kH, kW]`.
10    pub fn conv2d(
11        &mut self,
12        input: NodeId,
13        weight: NodeId,
14        kernel_size: [usize; 2],
15        stride: [usize; 2],
16        padding: [usize; 2],
17        dilation: [usize; 2],
18        groups: usize,
19    ) -> NodeId {
20        let in_s = self.node(input).shape.clone();
21        let w_s = self.node(weight).shape.clone();
22        let out = crate::shape::conv2d_output_shape(
23            &in_s,
24            &w_s,
25            kernel_size,
26            stride,
27            padding,
28            dilation,
29            groups,
30        )
31        .expect("conv2d shape inference");
32        self.push(
33            Op::Conv {
34                kernel_size: kernel_size.to_vec(),
35                stride: stride.to_vec(),
36                padding: padding.to_vec(),
37                dilation: dilation.to_vec(),
38                groups,
39            },
40            vec![input, weight],
41            out,
42            None,
43        )
44    }
45
46    /// 2D transposed convolution on NCHW. Weight `[C_in, C_out/g, kH, kW]`.
47    pub fn conv_transpose2d(
48        &mut self,
49        input: NodeId,
50        weight: NodeId,
51        kernel_size: [usize; 2],
52        stride: [usize; 2],
53        padding: [usize; 2],
54        dilation: [usize; 2],
55        output_padding: [usize; 2],
56        groups: usize,
57    ) -> NodeId {
58        let in_s = self.node(input).shape.clone();
59        let w_s = self.node(weight).shape.clone();
60        let out = crate::shape::conv_transpose2d_output_shape(
61            &in_s,
62            &w_s,
63            kernel_size,
64            stride,
65            padding,
66            dilation,
67            output_padding,
68            groups,
69        )
70        .expect("conv_transpose2d shape inference");
71        self.push(
72            Op::ConvTranspose2d {
73                kernel_size: kernel_size.to_vec(),
74                stride: stride.to_vec(),
75                padding: padding.to_vec(),
76                dilation: dilation.to_vec(),
77                output_padding: output_padding.to_vec(),
78                groups,
79            },
80            vec![input, weight],
81            out,
82            None,
83        )
84    }
85}