Skip to main content

rlx_ir/ops/
normalization.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Normalization builders: LayerNorm, LayerNorm2d, GroupNorm, BatchNorm (inference), and fused residual + LayerNorm / RMSNorm (plan #53).
17
18use crate::{Graph, NodeId, Op, Shape};
19
20impl Graph {
21    /// LayerNorm2d on NCHW (normalize across channels at each spatial position).
22    pub fn layer_norm2d(&mut self, input: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
23        let shape = self.node(input).shape.clone();
24        self.push(
25            Op::LayerNorm2d { eps },
26            vec![input, gamma, beta],
27            shape,
28            None,
29        )
30    }
31
32    /// Group normalization on NCHW.
33    pub fn group_norm(
34        &mut self,
35        input: NodeId,
36        gamma: NodeId,
37        beta: NodeId,
38        num_groups: usize,
39        eps: f32,
40    ) -> NodeId {
41        let shape = self.node(input).shape.clone();
42        self.push(
43            Op::GroupNorm { num_groups, eps },
44            vec![input, gamma, beta],
45            shape,
46            None,
47        )
48    }
49
50    /// BatchNorm inference (frozen running mean/variance).
51    pub fn batch_norm_inference(
52        &mut self,
53        input: NodeId,
54        gamma: NodeId,
55        beta: NodeId,
56        running_mean: NodeId,
57        running_var: NodeId,
58        eps: f32,
59    ) -> NodeId {
60        let shape = self.node(input).shape.clone();
61        self.push(
62            Op::BatchNormInference { eps },
63            vec![input, gamma, beta, running_mean, running_var],
64            shape,
65            None,
66        )
67    }
68
69    /// Layer normalization.
70    pub fn layer_norm(
71        &mut self,
72        input: NodeId,
73        gamma: NodeId,
74        beta: NodeId,
75        axis: i32,
76        eps: f32,
77        shape: Shape,
78    ) -> NodeId {
79        self.push(
80            Op::LayerNorm { axis, eps },
81            vec![input, gamma, beta],
82            shape,
83            None,
84        )
85    }
86
87    /// Fused residual + bias + layer norm (created by optimization passes).
88    pub fn fused_residual_ln(
89        &mut self,
90        x: NodeId,
91        residual: NodeId,
92        bias: Option<NodeId>,
93        gamma: NodeId,
94        beta: NodeId,
95        eps: f32,
96        shape: Shape,
97    ) -> NodeId {
98        let has_bias = bias.is_some();
99        let mut inputs = vec![x, residual];
100        if let Some(b) = bias {
101            inputs.push(b);
102        }
103        inputs.push(gamma);
104        inputs.push(beta);
105        self.push(Op::FusedResidualLN { has_bias, eps }, inputs, shape, None)
106    }
107
108    /// Fused residual + bias + RMS norm (created by optimization passes).
109    pub fn fused_residual_rms_norm(
110        &mut self,
111        x: NodeId,
112        residual: NodeId,
113        bias: Option<NodeId>,
114        gamma: NodeId,
115        beta: NodeId,
116        eps: f32,
117        shape: Shape,
118    ) -> NodeId {
119        let has_bias = bias.is_some();
120        let mut inputs = vec![x, residual];
121        if let Some(b) = bias {
122            inputs.push(b);
123        }
124        inputs.push(gamma);
125        inputs.push(beta);
126        self.push(
127            Op::FusedResidualRmsNorm { has_bias, eps },
128            inputs,
129            shape,
130            None,
131        )
132    }
133}