use serde::{Deserialize, Serialize};
use tensorlogic_ir::{EinsumGraph, EinsumNode};
use crate::error::{Result, TrustformerError};
/// Configuration shared by the normalization layers defined in this file
/// ([`LayerNorm`] and [`RMSNorm`]).
///
/// [`LayerNormConfig::new`] fills in `eps = 1e-5` and
/// `elementwise_affine = true`; the `with_*` builder methods override them.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LayerNormConfig {
/// Presumably the size of the feature dimension being normalized (TODO
/// confirm); within this file it is only checked to be non-zero.
pub normalized_shape: usize,
/// Epsilon term of the normalization; `validate` requires it to be positive.
pub eps: f64,
/// When `true`, the graph builders append the learned scale step (and, for
/// `LayerNorm`, the shift step) after normalization.
pub elementwise_affine: bool,
}
impl LayerNormConfig {
    /// Creates a configuration for `normalized_shape` with the defaults
    /// `eps = 1e-5` and `elementwise_affine = true`.
    ///
    /// Nothing is validated here; call [`validate`](Self::validate) (or build
    /// a layer from the config) to check the values.
    pub fn new(normalized_shape: usize) -> Self {
        Self {
            normalized_shape,
            eps: 1e-5,
            elementwise_affine: true,
        }
    }

    /// Builder-style setter that replaces `eps` and returns the updated config.
    pub fn with_eps(mut self, eps: f64) -> Self {
        self.eps = eps;
        self
    }

    /// Builder-style setter that replaces `elementwise_affine` and returns the
    /// updated config.
    pub fn with_elementwise_affine(mut self, elementwise_affine: bool) -> Self {
        self.elementwise_affine = elementwise_affine;
        self
    }

    /// Checks that the configuration holds usable values.
    ///
    /// # Errors
    ///
    /// Returns `TrustformerError::InvalidDimension` when
    /// * `normalized_shape` is `0`, or
    /// * `eps` is zero, negative, NaN, or infinite.
    pub fn validate(&self) -> Result<()> {
        if self.normalized_shape == 0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: "normalized_shape must be positive".to_string(),
            });
        }
        if self.eps <= 0.0 {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: format!("eps must be positive, got {}", self.eps),
            });
        }
        // Every ordered comparison with NaN is false, so NaN slips past the
        // `<= 0.0` check above; `+inf` passes it as well. Neither is a usable
        // epsilon, so reject both explicitly.
        if !self.eps.is_finite() {
            return Err(TrustformerError::InvalidDimension {
                expected: 1,
                got: 0,
                context: format!("eps must be finite, got {}", self.eps),
            });
        }
        Ok(())
    }
}
/// Layer-normalization graph builder.
///
/// Holds the [`LayerNormConfig`] it was created with; `LayerNorm::new`
/// validates that config before constructing the value.
#[derive(Clone, Debug)]
pub struct LayerNorm {
/// Configuration this layer was built from.
pub config: LayerNormConfig,
}
impl LayerNorm {
    /// Builds a `LayerNorm` from `config`, running
    /// [`LayerNormConfig::validate`] first.
    ///
    /// # Errors
    ///
    /// Returns the validation error unchanged when `config` is rejected.
    pub fn new(config: LayerNormConfig) -> Result<Self> {
        config.validate().map(|()| Self { config })
    }

    /// Appends the layer-normalization nodes to `graph`.
    ///
    /// The emitted chain is: mean of the input, input minus mean, square,
    /// mean of the squares, plus the `eps_const` tensor, square root, and
    /// division of the centered value by that root. With
    /// `elementwise_affine` enabled, a `mul` by tensor `1` and an `add` of
    /// tensor `2` follow.
    ///
    /// Assumes the caller has already registered the input as tensor `0` and,
    /// for the affine path, gamma/beta as tensors `1`/`2` (this is how the
    /// tests in this file set the graph up; confirm against `EinsumGraph`).
    /// Both reductions are issued with `vec![2]`, presumably the feature axis
    /// of a rank-3 input (TODO confirm).
    ///
    /// NOTE(review): the configured `eps` value is not read here; only a
    /// tensor named `"eps_const"` is registered, so whoever executes the graph
    /// has to supply its contents.
    ///
    /// Returns a one-element vector with the output tensor index:
    /// `"ln_output"` on the affine path, `"ln_normalized"` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `graph.add_node`.
    pub fn build_layernorm_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Tensor slots expected to be present before this call.
        const INPUT: usize = 0;
        const GAMMA: usize = 1;
        const BETA: usize = 2;

        // mean(x)
        let mean = graph.add_tensor("ln_mean");
        graph.add_node(EinsumNode::reduce("mean", vec![2], INPUT, mean))?;

        // x - mean(x)
        let centered = graph.add_tensor("ln_centered");
        graph.add_node(EinsumNode::elem_binary("sub", INPUT, mean, centered))?;

        // (x - mean(x))^2, expressed as a self-multiplication
        let squared = graph.add_tensor("ln_squared");
        graph.add_node(EinsumNode::elem_binary("mul", centered, centered, squared))?;

        // variance = mean of the squared deviations
        let variance = graph.add_tensor("ln_var");
        graph.add_node(EinsumNode::reduce("mean", vec![2], squared, variance))?;

        // variance + eps (the eps tensor is only registered here, not filled)
        let variance_plus_eps = graph.add_tensor("ln_var_eps");
        let eps_const = graph.add_tensor("eps_const");
        graph.add_node(EinsumNode::elem_binary(
            "add",
            variance,
            eps_const,
            variance_plus_eps,
        ))?;

        // sqrt(variance + eps)
        let std_dev = graph.add_tensor("ln_std");
        graph.add_node(EinsumNode::elem_unary("sqrt", variance_plus_eps, std_dev))?;

        // (x - mean(x)) / sqrt(variance + eps)
        let normalized = graph.add_tensor("ln_normalized");
        graph.add_node(EinsumNode::elem_binary("div", centered, std_dev, normalized))?;

        if !self.config.elementwise_affine {
            return Ok(vec![normalized]);
        }

        // normalized * gamma
        let scaled = graph.add_tensor("ln_scaled");
        graph.add_node(EinsumNode::elem_binary("mul", normalized, GAMMA, scaled))?;

        // ... + beta
        let shifted = graph.add_tensor("ln_output");
        graph.add_node(EinsumNode::elem_binary("add", scaled, BETA, shifted))?;

        Ok(vec![shifted])
    }

    /// Returns the `eps` stored in the configuration.
    pub fn eps(&self) -> f64 {
        let cfg = &self.config;
        cfg.eps
    }

    /// Returns whether the configuration enables the elementwise affine step.
    pub fn has_elementwise_affine(&self) -> bool {
        let cfg = &self.config;
        cfg.elementwise_affine
    }
}
/// RMS-normalization graph builder.
///
/// Reuses [`LayerNormConfig`] for its settings; `RMSNorm::new` validates that
/// config before constructing the value.
#[derive(Clone, Debug)]
pub struct RMSNorm {
/// Configuration this layer was built from.
pub config: LayerNormConfig,
}
impl RMSNorm {
    /// Builds an `RMSNorm` from `config`, running
    /// [`LayerNormConfig::validate`] first.
    ///
    /// # Errors
    ///
    /// Returns the validation error unchanged when `config` is rejected.
    pub fn new(config: LayerNormConfig) -> Result<Self> {
        config.validate().map(|()| Self { config })
    }

    /// Appends the RMS-normalization nodes to `graph`.
    ///
    /// The emitted chain is: input times itself, mean of that product, plus
    /// the `eps_const` tensor, square root, and division of the input by that
    /// root. With `elementwise_affine` enabled, a final `mul` by tensor `1`
    /// follows; there is no shift step.
    ///
    /// Assumes the caller has already registered the input as tensor `0` and,
    /// for the affine path, gamma as tensor `1` (this is how the test in this
    /// file sets the graph up; confirm against `EinsumGraph`). The reduction
    /// is issued with `vec![2]`, presumably the feature axis of a rank-3
    /// input (TODO confirm).
    ///
    /// NOTE(review): the configured `eps` value is not read here; only a
    /// tensor named `"eps_const"` is registered.
    ///
    /// Returns a one-element vector with the output tensor index:
    /// `"rms_output"` on the affine path, `"rms_normalized"` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `graph.add_node`.
    pub fn build_rmsnorm_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        // Tensor slots expected to be present before this call.
        const INPUT: usize = 0;
        const GAMMA: usize = 1;

        // x^2, expressed as a self-multiplication
        let squared = graph.add_tensor("rms_squared");
        graph.add_node(EinsumNode::elem_binary("mul", INPUT, INPUT, squared))?;

        // mean(x^2)
        let mean_square = graph.add_tensor("rms_mean_sq");
        graph.add_node(EinsumNode::reduce("mean", vec![2], squared, mean_square))?;

        // mean(x^2) + eps (the eps tensor is only registered here, not filled)
        let mean_square_plus_eps = graph.add_tensor("rms_mean_sq_eps");
        let eps_const = graph.add_tensor("eps_const");
        graph.add_node(EinsumNode::elem_binary(
            "add",
            mean_square,
            eps_const,
            mean_square_plus_eps,
        ))?;

        // sqrt(mean(x^2) + eps)
        let rms = graph.add_tensor("rms");
        graph.add_node(EinsumNode::elem_unary("sqrt", mean_square_plus_eps, rms))?;

        // x / rms
        let normalized = graph.add_tensor("rms_normalized");
        graph.add_node(EinsumNode::elem_binary("div", INPUT, rms, normalized))?;

        if !self.config.elementwise_affine {
            return Ok(vec![normalized]);
        }

        // normalized * gamma
        let scaled = graph.add_tensor("rms_output");
        graph.add_node(EinsumNode::elem_binary("mul", normalized, GAMMA, scaled))?;

        Ok(vec![scaled])
    }

    /// Returns the `eps` stored in the configuration.
    pub fn eps(&self) -> f64 {
        let cfg = &self.config;
        cfg.eps
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the normalization config and the two graph builders.
    //!
    //! Graph tests register the input tensors by hand before calling the
    //! builder, because the builders refer to them by fixed index
    //! (`0` = input, `1` = gamma, `2` = beta).

    use super::*;

    #[test]
    fn test_layernorm_config_creation() {
        let config = LayerNormConfig::new(512);
        assert_eq!(config.normalized_shape, 512);
        assert!((config.eps - 1e-5).abs() < 1e-10);
        assert!(config.elementwise_affine);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_layernorm_config_with_eps() {
        let config = LayerNormConfig::new(512).with_eps(1e-6);
        assert!((config.eps - 1e-6).abs() < 1e-10);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_layernorm_config_without_affine() {
        let config = LayerNormConfig::new(512).with_elementwise_affine(false);
        assert!(!config.elementwise_affine);
    }

    #[test]
    fn test_layernorm_creation() {
        let config = LayerNormConfig::new(512);
        let ln = LayerNorm::new(config).expect("default config for shape 512 should validate");
        assert_eq!(ln.config.normalized_shape, 512);
        assert!(ln.has_elementwise_affine());
    }

    #[test]
    fn test_layernorm_graph_building_with_affine() {
        let config = LayerNormConfig::new(512);
        let ln = LayerNorm::new(config).expect("default config for shape 512 should validate");
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("gamma");
        graph.add_tensor("beta");
        let outputs = ln
            .build_layernorm_graph(&mut graph)
            .expect("layernorm graph should build when x, gamma and beta are registered");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_layernorm_graph_building_without_affine() {
        let config = LayerNormConfig::new(512).with_elementwise_affine(false);
        let ln = LayerNorm::new(config).expect("non-affine config for shape 512 should validate");
        assert!(!ln.has_elementwise_affine());
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        let outputs = ln
            .build_layernorm_graph(&mut graph)
            .expect("non-affine layernorm graph should build when x is registered");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_rmsnorm_creation() {
        let config = LayerNormConfig::new(512);
        let rms = RMSNorm::new(config).expect("default config for shape 512 should validate");
        assert_eq!(rms.config.normalized_shape, 512);
    }

    #[test]
    fn test_rmsnorm_graph_building() {
        let config = LayerNormConfig::new(512);
        let rms = RMSNorm::new(config).expect("default config for shape 512 should validate");
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        graph.add_tensor("gamma");
        let outputs = rms
            .build_rmsnorm_graph(&mut graph)
            .expect("rmsnorm graph should build when x and gamma are registered");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    // Mirrors the LayerNorm non-affine test so both builders cover both paths.
    #[test]
    fn test_rmsnorm_graph_building_without_affine() {
        let config = LayerNormConfig::new(512).with_elementwise_affine(false);
        let rms = RMSNorm::new(config).expect("non-affine config for shape 512 should validate");
        let mut graph = EinsumGraph::new();
        graph.add_tensor("x");
        let outputs = rms
            .build_rmsnorm_graph(&mut graph)
            .expect("non-affine rmsnorm graph should build when x is registered");
        assert_eq!(outputs.len(), 1);
        assert!(!graph.nodes.is_empty());
    }

    #[test]
    fn test_invalid_config_zero_shape() {
        let config = LayerNormConfig::new(0);
        assert!(config.validate().is_err());
        assert!(LayerNorm::new(config.clone()).is_err());
        assert!(RMSNorm::new(config).is_err());
    }

    #[test]
    fn test_invalid_config_negative_eps() {
        let config = LayerNormConfig::new(512).with_eps(-1e-5);
        assert!(config.validate().is_err());
    }

    // Boundary case: zero is not a positive epsilon.
    #[test]
    fn test_invalid_config_zero_eps() {
        let config = LayerNormConfig::new(512).with_eps(0.0);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_layernorm_eps() {
        let config = LayerNormConfig::new(512).with_eps(1e-6);
        let ln = LayerNorm::new(config).expect("config with eps = 1e-6 should validate");
        assert!((ln.eps() - 1e-6).abs() < 1e-10);
    }

    #[test]
    fn test_rmsnorm_eps() {
        let config = LayerNormConfig::new(512).with_eps(1e-6);
        let rms = RMSNorm::new(config).expect("config with eps = 1e-6 should validate");
        assert!((rms.eps() - 1e-6).abs() < 1e-10);
    }
}