rlx_ir/ops/
normalization.rs1use crate::{Graph, NodeId, Op, Shape};
19
20impl Graph {
21 pub fn layer_norm2d(&mut self, input: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
23 let shape = self.node(input).shape.clone();
24 self.push(
25 Op::LayerNorm2d { eps },
26 vec![input, gamma, beta],
27 shape,
28 None,
29 )
30 }
31
32 pub fn group_norm(
34 &mut self,
35 input: NodeId,
36 gamma: NodeId,
37 beta: NodeId,
38 num_groups: usize,
39 eps: f32,
40 ) -> NodeId {
41 let shape = self.node(input).shape.clone();
42 self.push(
43 Op::GroupNorm { num_groups, eps },
44 vec![input, gamma, beta],
45 shape,
46 None,
47 )
48 }
49
50 pub fn batch_norm_inference(
52 &mut self,
53 input: NodeId,
54 gamma: NodeId,
55 beta: NodeId,
56 running_mean: NodeId,
57 running_var: NodeId,
58 eps: f32,
59 ) -> NodeId {
60 let shape = self.node(input).shape.clone();
61 self.push(
62 Op::BatchNormInference { eps },
63 vec![input, gamma, beta, running_mean, running_var],
64 shape,
65 None,
66 )
67 }
68
69 pub fn layer_norm(
71 &mut self,
72 input: NodeId,
73 gamma: NodeId,
74 beta: NodeId,
75 axis: i32,
76 eps: f32,
77 shape: Shape,
78 ) -> NodeId {
79 self.push(
80 Op::LayerNorm { axis, eps },
81 vec![input, gamma, beta],
82 shape,
83 None,
84 )
85 }
86
87 pub fn fused_residual_ln(
89 &mut self,
90 x: NodeId,
91 residual: NodeId,
92 bias: Option<NodeId>,
93 gamma: NodeId,
94 beta: NodeId,
95 eps: f32,
96 shape: Shape,
97 ) -> NodeId {
98 let has_bias = bias.is_some();
99 let mut inputs = vec![x, residual];
100 if let Some(b) = bias {
101 inputs.push(b);
102 }
103 inputs.push(gamma);
104 inputs.push(beta);
105 self.push(Op::FusedResidualLN { has_bias, eps }, inputs, shape, None)
106 }
107
108 pub fn fused_residual_rms_norm(
110 &mut self,
111 x: NodeId,
112 residual: NodeId,
113 bias: Option<NodeId>,
114 gamma: NodeId,
115 beta: NodeId,
116 eps: f32,
117 shape: Shape,
118 ) -> NodeId {
119 let has_bias = bias.is_some();
120 let mut inputs = vec![x, residual];
121 if let Some(b) = bias {
122 inputs.push(b);
123 }
124 inputs.push(gamma);
125 inputs.push(beta);
126 self.push(
127 Op::FusedResidualRmsNorm { has_bias, eps },
128 inputs,
129 shape,
130 None,
131 )
132 }
133}