Skip to main content

rlx_ir/ops/
conv2d.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! NCHW convolution builders (`conv2d`, `im2col`, `conv_transpose2d`).
17
18use crate::{Graph, NodeId, Op};
19
20impl Graph {
21    /// 2D convolution on NCHW tensors (`Op::Conv`). Weight `[C_out, C_in/g, kH, kW]`.
22    pub fn conv2d(
23        &mut self,
24        input: NodeId,
25        weight: NodeId,
26        kernel_size: [usize; 2],
27        stride: [usize; 2],
28        padding: [usize; 2],
29        dilation: [usize; 2],
30        groups: usize,
31    ) -> NodeId {
32        let in_s = self.node(input).shape.clone();
33        let w_s = self.node(weight).shape.clone();
34        let out = crate::shape::conv2d_output_shape(
35            &in_s,
36            &w_s,
37            kernel_size,
38            stride,
39            padding,
40            dilation,
41            groups,
42        )
43        .expect("conv2d shape inference");
44        self.push(
45            Op::Conv {
46                kernel_size: kernel_size.to_vec(),
47                stride: stride.to_vec(),
48                padding: padding.to_vec(),
49                dilation: dilation.to_vec(),
50                groups,
51            },
52            vec![input, weight],
53            out,
54            None,
55        )
56    }
57
58    /// NCHW im2col (`Op::Im2Col`). Output `[M, C·kH·kW]`.
59    pub fn im2col(
60        &mut self,
61        input: NodeId,
62        kernel_size: [usize; 2],
63        stride: [usize; 2],
64        padding: [usize; 2],
65        dilation: [usize; 2],
66    ) -> NodeId {
67        let in_s = self.node(input).shape.clone();
68        let out = crate::shape::im2col_output_shape(&in_s, kernel_size, stride, padding, dilation)
69            .expect("im2col shape inference");
70        self.push(
71            Op::Im2Col {
72                kernel_size: kernel_size.to_vec(),
73                stride: stride.to_vec(),
74                padding: padding.to_vec(),
75                dilation: dilation.to_vec(),
76            },
77            vec![input],
78            out,
79            None,
80        )
81    }
82
83    /// 2D transposed convolution on NCHW. Weight `[C_in, C_out/g, kH, kW]`.
84    pub fn conv_transpose2d(
85        &mut self,
86        input: NodeId,
87        weight: NodeId,
88        kernel_size: [usize; 2],
89        stride: [usize; 2],
90        padding: [usize; 2],
91        dilation: [usize; 2],
92        output_padding: [usize; 2],
93        groups: usize,
94    ) -> NodeId {
95        let in_s = self.node(input).shape.clone();
96        let w_s = self.node(weight).shape.clone();
97        let out = crate::shape::conv_transpose2d_output_shape(
98            &in_s,
99            &w_s,
100            kernel_size,
101            stride,
102            padding,
103            dilation,
104            output_padding,
105            groups,
106        )
107        .expect("conv_transpose2d shape inference");
108        self.push(
109            Op::ConvTranspose2d {
110                kernel_size: kernel_size.to_vec(),
111                stride: stride.to_vec(),
112                padding: padding.to_vec(),
113                dilation: dilation.to_vec(),
114                output_padding: output_padding.to_vec(),
115                groups,
116            },
117            vec![input, weight],
118            out,
119            None,
120        )
121    }
122}