rlx_ir/ops/
normalization.rs1use crate::{Graph, NodeId, Op, Shape};
19
20impl Graph {
21 pub fn layer_norm2d(&mut self, input: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
23 let shape = self.node(input).shape.clone();
24 self.push(
25 Op::LayerNorm2d { eps },
26 vec![input, gamma, beta],
27 shape,
28 None,
29 )
30 }
31
32 pub fn group_norm(
34 &mut self,
35 input: NodeId,
36 gamma: NodeId,
37 beta: NodeId,
38 num_groups: usize,
39 eps: f32,
40 ) -> NodeId {
41 let shape = self.node(input).shape.clone();
42 self.push(
43 Op::GroupNorm { num_groups, eps },
44 vec![input, gamma, beta],
45 shape,
46 None,
47 )
48 }
49
50 pub fn layer_norm(
52 &mut self,
53 input: NodeId,
54 gamma: NodeId,
55 beta: NodeId,
56 axis: i32,
57 eps: f32,
58 shape: Shape,
59 ) -> NodeId {
60 self.push(
61 Op::LayerNorm { axis, eps },
62 vec![input, gamma, beta],
63 shape,
64 None,
65 )
66 }
67
68 pub fn fused_residual_ln(
70 &mut self,
71 x: NodeId,
72 residual: NodeId,
73 bias: Option<NodeId>,
74 gamma: NodeId,
75 beta: NodeId,
76 eps: f32,
77 shape: Shape,
78 ) -> NodeId {
79 let has_bias = bias.is_some();
80 let mut inputs = vec![x, residual];
81 if let Some(b) = bias {
82 inputs.push(b);
83 }
84 inputs.push(gamma);
85 inputs.push(beta);
86 self.push(Op::FusedResidualLN { has_bias, eps }, inputs, shape, None)
87 }
88
89 pub fn fused_residual_rms_norm(
91 &mut self,
92 x: NodeId,
93 residual: NodeId,
94 bias: Option<NodeId>,
95 gamma: NodeId,
96 beta: NodeId,
97 eps: f32,
98 shape: Shape,
99 ) -> NodeId {
100 let has_bias = bias.is_some();
101 let mut inputs = vec![x, residual];
102 if let Some(b) = bias {
103 inputs.push(b);
104 }
105 inputs.push(gamma);
106 inputs.push(beta);
107 self.push(
108 Op::FusedResidualRmsNorm { has_bias, eps },
109 inputs,
110 shape,
111 None,
112 )
113 }
114}