use crate::{Graph, NodeId, Op, Shape};
impl Graph {
    /// Appends an [`Op::LayerNorm2d`] node and returns its id.
    ///
    /// Operands are pushed in the order `[input, gamma, beta]`. The output shape is a
    /// copy of `input`'s shape.
    ///
    /// `eps` is forwarded unchanged into the op.
    // NOTE(review): the trailing `None` passed to `push` is opaque from this file —
    // confirm its meaning against `Graph::push`.
    pub fn layer_norm2d(&mut self, input: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
        let out_shape = self.node(input).shape.clone();
        self.push(Op::LayerNorm2d { eps }, vec![input, gamma, beta], out_shape, None)
    }

    /// Appends an [`Op::GroupNorm`] node and returns its id.
    ///
    /// Operands are pushed in the order `[input, gamma, beta]`. The output shape is a
    /// copy of `input`'s shape.
    ///
    /// `num_groups` and `eps` are forwarded unchanged into the op. No validation is
    /// done here; in particular, `num_groups == 0` is not rejected at build time.
    pub fn group_norm(
        &mut self,
        input: NodeId,
        gamma: NodeId,
        beta: NodeId,
        num_groups: usize,
        eps: f32,
    ) -> NodeId {
        let out_shape = self.node(input).shape.clone();
        self.push(
            Op::GroupNorm { num_groups, eps },
            vec![input, gamma, beta],
            out_shape,
            None,
        )
    }

    /// Appends an [`Op::LayerNorm`] node and returns its id.
    ///
    /// Operands are pushed in the order `[input, gamma, beta]`.
    ///
    /// Unlike [`Graph::layer_norm2d`] and [`Graph::group_norm`], the output `shape` is
    /// supplied by the caller rather than copied from `input`.
    // NOTE(review): `axis` is signed; presumably negative values index from the last
    // dimension — TODO confirm against the LayerNorm kernel.
    pub fn layer_norm(
        &mut self,
        input: NodeId,
        gamma: NodeId,
        beta: NodeId,
        axis: i32,
        eps: f32,
        shape: Shape,
    ) -> NodeId {
        self.push(Op::LayerNorm { axis, eps }, vec![input, gamma, beta], shape, None)
    }

    /// Appends an [`Op::FusedResidualLN`] node and returns its id.
    ///
    /// Operands are `[x, residual, bias?, gamma, beta]`. `bias` is included only when it
    /// is `Some`, and the op's `has_bias` flag records whether it was included. The
    /// output `shape` is supplied by the caller.
    pub fn fused_residual_ln(
        &mut self,
        x: NodeId,
        residual: NodeId,
        bias: Option<NodeId>,
        gamma: NodeId,
        beta: NodeId,
        eps: f32,
        shape: Shape,
    ) -> NodeId {
        let (has_bias, operands) = collect_residual_norm_inputs(x, residual, bias, gamma, beta);
        self.push(Op::FusedResidualLN { has_bias, eps }, operands, shape, None)
    }

    /// Appends an [`Op::FusedResidualRmsNorm`] node and returns its id.
    ///
    /// Operands are `[x, residual, bias?, gamma, beta]`. This is the same layout as
    /// [`Graph::fused_residual_ln`]. `bias` is included only when it is `Some`, and
    /// `has_bias` records whether it was included. The output `shape` is supplied by
    /// the caller.
    // NOTE(review): RMS normalization conventionally has no additive shift; `beta` is
    // still wired in as an operand — confirm the kernel actually consumes it.
    pub fn fused_residual_rms_norm(
        &mut self,
        x: NodeId,
        residual: NodeId,
        bias: Option<NodeId>,
        gamma: NodeId,
        beta: NodeId,
        eps: f32,
        shape: Shape,
    ) -> NodeId {
        let (has_bias, operands) = collect_residual_norm_inputs(x, residual, bias, gamma, beta);
        self.push(
            Op::FusedResidualRmsNorm { has_bias, eps },
            operands,
            shape,
            None,
        )
    }
}

/// Builds the operand list shared by the fused residual-norm ops.
///
/// The list is `[x, residual, bias?, gamma, beta]`, returned together with a flag
/// saying whether `bias` was present.
///
/// Keeping this in one place guarantees both fused ops agree on operand order.
fn collect_residual_norm_inputs(
    x: NodeId,
    residual: NodeId,
    bias: Option<NodeId>,
    gamma: NodeId,
    beta: NodeId,
) -> (bool, Vec<NodeId>) {
    let has_bias = bias.is_some();
    // At most five operands: x, residual, optional bias, gamma, beta.
    let mut operands = Vec::with_capacity(5);
    operands.push(x);
    operands.push(residual);
    // `Option` iterates over zero or one item, so this appends bias only when present.
    operands.extend(bias);
    operands.push(gamma);
    operands.push(beta);
    (has_bias, operands)
}