Skip to main content

rlx_ir/ops/
normalization.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Normalization builders: LayerNorm (axis-based and 2-D NCHW), GroupNorm, and the
//! fused residual+LayerNorm / residual+RMSNorm ops (plan #53).

18use crate::{Graph, NodeId, Op, Shape};
19
20impl Graph {
21    /// LayerNorm2d on NCHW (normalize across channels at each spatial position).
22    pub fn layer_norm2d(&mut self, input: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
23        let shape = self.node(input).shape.clone();
24        self.push(
25            Op::LayerNorm2d { eps },
26            vec![input, gamma, beta],
27            shape,
28            None,
29        )
30    }
31
32    /// Group normalization on NCHW.
33    pub fn group_norm(
34        &mut self,
35        input: NodeId,
36        gamma: NodeId,
37        beta: NodeId,
38        num_groups: usize,
39        eps: f32,
40    ) -> NodeId {
41        let shape = self.node(input).shape.clone();
42        self.push(
43            Op::GroupNorm { num_groups, eps },
44            vec![input, gamma, beta],
45            shape,
46            None,
47        )
48    }
49
50    /// Layer normalization.
51    pub fn layer_norm(
52        &mut self,
53        input: NodeId,
54        gamma: NodeId,
55        beta: NodeId,
56        axis: i32,
57        eps: f32,
58        shape: Shape,
59    ) -> NodeId {
60        self.push(
61            Op::LayerNorm { axis, eps },
62            vec![input, gamma, beta],
63            shape,
64            None,
65        )
66    }
67
68    /// Fused residual + bias + layer norm (created by optimization passes).
69    pub fn fused_residual_ln(
70        &mut self,
71        x: NodeId,
72        residual: NodeId,
73        bias: Option<NodeId>,
74        gamma: NodeId,
75        beta: NodeId,
76        eps: f32,
77        shape: Shape,
78    ) -> NodeId {
79        let has_bias = bias.is_some();
80        let mut inputs = vec![x, residual];
81        if let Some(b) = bias {
82            inputs.push(b);
83        }
84        inputs.push(gamma);
85        inputs.push(beta);
86        self.push(Op::FusedResidualLN { has_bias, eps }, inputs, shape, None)
87    }
88
89    /// Fused residual + bias + RMS norm (created by optimization passes).
90    pub fn fused_residual_rms_norm(
91        &mut self,
92        x: NodeId,
93        residual: NodeId,
94        bias: Option<NodeId>,
95        gamma: NodeId,
96        beta: NodeId,
97        eps: f32,
98        shape: Shape,
99    ) -> NodeId {
100        let has_bias = bias.is_some();
101        let mut inputs = vec![x, residual];
102        if let Some(b) = bias {
103            inputs.push(b);
104        }
105        inputs.push(gamma);
106        inputs.push(beta);
107        self.push(
108            Op::FusedResidualRmsNorm { has_bias, eps },
109            inputs,
110            shape,
111            None,
112        )
113    }
114}