use super::ops::*;
use crate::autograd::Var;
use crate::error::Result;
use crate::ops::{NormalizationOps, ScalarOps, TensorOps};
use crate::runtime::{Runtime, RuntimeClient};
use std::sync::Arc;
/// RMS-normalizes `input` scaled elementwise by `weight`, wiring up autograd.
///
/// Runs the forward kernel through `client`; if either operand tracks
/// gradients, attaches an `RmsNormBackward` node capturing the tensors the
/// backward pass needs.
///
/// # Errors
/// Propagates any error returned by the underlying `rms_norm` kernel.
pub fn var_rms_norm<R, C>(input: &Var<R>, weight: &Var<R>, eps: f32, client: &C) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + NormalizationOps<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
{
    let output = client.rms_norm(input.tensor(), weight.tensor(), eps)?;
    // Fast path: neither operand is tracked, so no graph node is required.
    if !(input.requires_grad() || weight.requires_grad()) {
        return Ok(Var::new(output, false));
    }
    // Record the backward node, capturing the forward inputs and the
    // upstream grad_fns of both operands.
    let grad_fn = RmsNormBackward::<R>::new(
        input.id(),
        weight.id(),
        input.tensor().clone(),
        weight.tensor().clone(),
        eps,
        input.grad_fn().cloned(),
        weight.grad_fn().cloned(),
    );
    Ok(Var::from_op(output, Arc::new(grad_fn)))
}
/// Layer normalization of `input` with learnable `weight` and `bias`,
/// recording a `LayerNormBackward` node when any operand tracks gradients.
///
/// # Errors
/// Propagates any error returned by the underlying `layer_norm` kernel.
pub fn var_layer_norm<R, C>(
    input: &Var<R>,
    weight: &Var<R>,
    bias: &Var<R>,
    eps: f32,
    client: &C,
) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + NormalizationOps<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
{
    let output = client.layer_norm(input.tensor(), weight.tensor(), bias.tensor(), eps)?;
    // Any tracked operand forces graph construction.
    if ![input, weight, bias].iter().any(|v| v.requires_grad()) {
        return Ok(Var::new(output, false));
    }
    let grad_fn = LayerNormBackward::<R>::new(
        input.id(),
        weight.id(),
        bias.id(),
        input.tensor().clone(),
        weight.tensor().clone(),
        eps,
        input.grad_fn().cloned(),
        weight.grad_fn().cloned(),
        bias.grad_fn().cloned(),
    );
    Ok(Var::from_op(output, Arc::new(grad_fn)))
}
/// Group normalization over `num_groups` channel groups with affine
/// `weight`/`bias`, autograd-aware.
///
/// # Errors
/// Propagates any error returned by the underlying `group_norm` kernel.
pub fn var_group_norm<R, C>(
    input: &Var<R>,
    weight: &Var<R>,
    bias: &Var<R>,
    num_groups: usize,
    eps: f32,
    client: &C,
) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + NormalizationOps<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
{
    let output = client.group_norm(
        input.tensor(),
        weight.tensor(),
        bias.tensor(),
        num_groups,
        eps,
    )?;
    let tracked = input.requires_grad() || weight.requires_grad() || bias.requires_grad();
    if !tracked {
        // No operand participates in autograd; return a detached Var.
        return Ok(Var::new(output, false));
    }
    let grad_fn = GroupNormBackward::<R>::new(
        input.id(),
        weight.id(),
        bias.id(),
        input.tensor().clone(),
        weight.tensor().clone(),
        num_groups,
        eps,
        input.grad_fn().cloned(),
        weight.grad_fn().cloned(),
        bias.grad_fn().cloned(),
    );
    Ok(Var::from_op(output, Arc::new(grad_fn)))
}
/// Fused residual-add followed by RMS norm: normalizes `x + residual`
/// scaled by `weight`, with autograd support.
///
/// The kernel also returns the pre-normalization sum, which the backward
/// node stores instead of the two separate inputs.
///
/// # Errors
/// Propagates any error returned by the underlying `fused_add_rms_norm` kernel.
pub fn var_fused_add_rms_norm<R, C>(
    x: &Var<R>,
    residual: &Var<R>,
    weight: &Var<R>,
    eps: f32,
    client: &C,
) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + NormalizationOps<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
{
    let (output, pre_norm) =
        client.fused_add_rms_norm(x.tensor(), residual.tensor(), weight.tensor(), eps)?;
    if ![x, residual, weight].iter().any(|v| v.requires_grad()) {
        return Ok(Var::new(output, false));
    }
    let grad_fn = FusedAddRmsNormBackward::<R>::new(
        x.id(),
        residual.id(),
        weight.id(),
        pre_norm,
        weight.tensor().clone(),
        eps,
        x.grad_fn().cloned(),
        residual.grad_fn().cloned(),
        weight.grad_fn().cloned(),
    );
    Ok(Var::from_op(output, Arc::new(grad_fn)))
}
/// Fused residual-add followed by layer norm: normalizes `x + residual`
/// with affine `weight`/`bias`, with autograd support.
///
/// The kernel returns the pre-normalization sum alongside the output; the
/// backward node stores that sum rather than `x` and `residual` separately.
///
/// # Errors
/// Propagates any error returned by the underlying `fused_add_layer_norm`
/// kernel.
pub fn var_fused_add_layer_norm<R, C>(
    x: &Var<R>,
    residual: &Var<R>,
    weight: &Var<R>,
    bias: &Var<R>,
    eps: f32,
    client: &C,
) -> Result<Var<R>>
where
    R: Runtime,
    C: RuntimeClient<R> + NormalizationOps<R>,
    R::Client: TensorOps<R> + ScalarOps<R>,
{
    let (output, pre_norm) = client.fused_add_layer_norm(
        x.tensor(),
        residual.tensor(),
        weight.tensor(),
        bias.tensor(),
        eps,
    )?;
    let tracked = x.requires_grad()
        || residual.requires_grad()
        || weight.requires_grad()
        || bias.requires_grad();
    if !tracked {
        return Ok(Var::new(output, false));
    }
    let grad_fn = FusedAddLayerNormBackward::<R>::new(
        x.id(),
        residual.id(),
        weight.id(),
        bias.id(),
        pre_norm,
        weight.tensor().clone(),
        bias.tensor().clone(),
        eps,
        x.grad_fn().cloned(),
        residual.grad_fn().cloned(),
        weight.grad_fn().cloned(),
        bias.grad_fn().cloned(),
    );
    Ok(Var::from_op(output, Arc::new(grad_fn)))
}
#[cfg(test)]
mod tests {
// CPU-backed unit tests for the autograd normalization wrappers.
// Forward tests check closed-form values or invariants of the output;
// backward tests check gradient shapes, analytically known gradients
// (bias grads under a sum loss), and finiteness as a smoke test.
use super::*;
use crate::autograd::backward;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
use crate::tensor::Tensor;
#[test]
fn test_var_rms_norm_forward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let input = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device),
true,
);
let result = var_rms_norm(&input, &weight, 1e-5, &client).unwrap();
let data: Vec<f32> = result.tensor().to_vec();
// Mean of squares of [1,2,3,4] is (1+4+9+16)/4 = 7.5; with unit weight
// each output should be x / sqrt(7.5 + eps).
let rms = (7.5f32 + 1e-5).sqrt();
for i in 0..4 {
let expected = (i as f32 + 1.0) / rms;
assert!(
(data[i] - expected).abs() < 1e-5,
"data[{}] = {}, expected {}",
i,
data[i],
expected,
);
}
}
#[test]
fn test_var_rms_norm_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let input = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device),
true,
);
let output = var_rms_norm(&input, &weight, 1e-5, &client).unwrap();
// Reduce over both axes so backward starts from a scalar loss.
let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap();
let grads = backward(&loss, &client).unwrap();
let grad_input = grads.get(input.id()).unwrap();
let grad_weight = grads.get(weight.id()).unwrap();
let gi: Vec<f32> = grad_input.to_vec();
let gw: Vec<f32> = grad_weight.to_vec();
// Smoke checks only: gradients exist, match input sizes, and are finite.
assert_eq!(gi.len(), 3);
assert_eq!(gw.len(), 3);
for val in &gi {
assert!(val.is_finite(), "input gradient should be finite");
}
for val in &gw {
assert!(val.is_finite(), "weight gradient should be finite");
}
}
#[test]
fn test_var_layer_norm_forward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let input = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device),
true,
);
let bias = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device),
true,
);
let result = var_layer_norm(&input, &weight, &bias, 1e-5, &client).unwrap();
let data: Vec<f32> = result.tensor().to_vec();
// With unit weight and zero bias the normalized row must be zero-mean.
let sum: f32 = data.iter().sum();
assert!(sum.abs() < 1e-4, "layer norm output should have ~0 mean");
}
#[test]
fn test_var_layer_norm_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let input = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device),
true,
);
let bias = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0], &[3], &device),
true,
);
let output = var_layer_norm(&input, &weight, &bias, 1e-5, &client).unwrap();
let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap();
let grads = backward(&loss, &client).unwrap();
let grad_input = grads.get(input.id()).unwrap();
let grad_weight = grads.get(weight.id()).unwrap();
let grad_bias = grads.get(bias.id()).unwrap();
let gi: Vec<f32> = grad_input.to_vec();
let gw: Vec<f32> = grad_weight.to_vec();
let gb: Vec<f32> = grad_bias.to_vec();
assert_eq!(gi.len(), 3);
assert_eq!(gw.len(), 3);
assert_eq!(gb.len(), 3);
// d(sum)/d(bias_j) = 1 exactly: bias is added once per output element.
for val in &gb {
assert!(
(*val - 1.0).abs() < 1e-5,
"bias gradient should be 1.0, got {}",
val,
);
}
// Layer-norm output is invariant to a constant input shift, so the
// input gradient of a sum loss must sum to ~0.
let sum: f32 = gi.iter().sum();
assert!(
sum.abs() < 1e-5,
"sum of input gradients should be ~0, got {}",
sum,
);
for val in &gi {
assert!(val.is_finite());
}
for val in &gw {
assert!(val.is_finite());
}
}
#[test]
fn test_var_rms_norm_no_grad() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let input = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[1, 2], &device),
false,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device),
false,
);
// With no tracked operand the wrapper must skip graph construction.
let result = var_rms_norm(&input, &weight, 1e-5, &client).unwrap();
assert!(!result.requires_grad());
assert!(result.grad_fn().is_none());
}
#[test]
fn test_var_group_norm_forward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
// Shape [1, 4, 3]: 4 channels of 3 elements, split into 2 groups of
// 2 channels (6 elements per group).
let input = Var::new(
Tensor::<CpuRuntime>::from_slice(
&[
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
&[1, 4, 3],
&device,
),
false,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device),
false,
);
let bias = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device),
false,
);
let result = var_group_norm(&input, &weight, &bias, 2, 1e-5, &client).unwrap();
assert_eq!(result.tensor().shape(), &[1, 4, 3]);
let data: Vec<f32> = result.tensor().to_vec();
// Each group is normalized independently, so each 6-element group of
// the output should be zero-mean (unit weight, zero bias).
let group0_sum: f32 = data[0..6].iter().sum();
assert!(
group0_sum.abs() < 1e-4,
"group 0 mean should be ~0, sum={group0_sum}"
);
let group1_sum: f32 = data[6..12].iter().sum();
assert!(
group1_sum.abs() < 1e-4,
"group 1 mean should be ~0, sum={group1_sum}"
);
}
#[test]
fn test_var_group_norm_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let input = Var::new(
Tensor::<CpuRuntime>::from_slice(
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
&[1, 4, 2],
&device,
),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device),
true,
);
let bias = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device),
true,
);
let output = var_group_norm(&input, &weight, &bias, 2, 1e-5, &client).unwrap();
// Empty axis list: presumably reduces over all axes — full sum loss.
let loss = crate::autograd::var_sum(&output, &[], false, &client).unwrap();
let grads = backward(&loss, &client).unwrap();
let d_input: Vec<f32> = grads.get(input.id()).unwrap().to_vec();
let d_weight: Vec<f32> = grads.get(weight.id()).unwrap().to_vec();
let d_bias: Vec<f32> = grads.get(bias.id()).unwrap().to_vec();
assert_eq!(d_input.len(), 8);
assert_eq!(d_weight.len(), 4);
assert_eq!(d_bias.len(), 4);
// Each bias element is broadcast over 2 spatial positions, so under a
// sum loss its gradient is exactly 2.
for &b in &d_bias {
assert!((b - 2.0).abs() < 1e-5, "d_bias should be 2.0, got {b}");
}
for v in d_input.iter().chain(d_weight.iter()) {
assert!(v.is_finite());
}
}
#[test]
fn test_var_fused_add_rms_norm_forward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device),
true,
);
let residual = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2, 0.3, 0.4], &[1, 4], &device),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device),
true,
);
// Shape/finiteness smoke test only; the exact values depend on the
// fused kernel's normalization of (x + residual).
let result = var_fused_add_rms_norm(&x, &residual, &weight, 1e-5, &client).unwrap();
let data: Vec<f32> = result.tensor().to_vec();
assert_eq!(data.len(), 4);
for val in &data {
assert!(val.is_finite());
}
}
#[test]
fn test_var_fused_add_rms_norm_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device),
true,
);
let residual = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2, 0.3], &[1, 3], &device),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device),
true,
);
let output = var_fused_add_rms_norm(&x, &residual, &weight, 1e-5, &client).unwrap();
let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap();
let grads = backward(&loss, &client).unwrap();
let gx: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
let gr: Vec<f32> = grads.get(residual.id()).unwrap().to_vec();
let gw: Vec<f32> = grads.get(weight.id()).unwrap().to_vec();
assert_eq!(gx.len(), 3);
assert_eq!(gr.len(), 3);
assert_eq!(gw.len(), 3);
// x and residual enter only through their sum, so their gradients
// must be identical.
for (a, b) in gx.iter().zip(gr.iter()) {
assert!(
(a - b).abs() < 1e-5,
"x and residual grads must match: {a} vs {b}"
);
}
for val in gx.iter().chain(gw.iter()) {
assert!(val.is_finite());
}
}
#[test]
fn test_var_fused_add_rms_norm_no_grad() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[1, 2], &device),
false,
);
let residual = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2], &[1, 2], &device),
false,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0], &[2], &device),
false,
);
// All inputs untracked: the result must be detached from the graph.
let result = var_fused_add_rms_norm(&x, &residual, &weight, 1e-5, &client).unwrap();
assert!(!result.requires_grad());
}
#[test]
fn test_var_fused_add_layer_norm_forward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[1, 4], &device),
true,
);
let residual = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2, 0.3, 0.4], &[1, 4], &device),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0, 1.0], &[4], &device),
true,
);
let bias = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0, 0.0], &[4], &device),
true,
);
let result =
var_fused_add_layer_norm(&x, &residual, &weight, &bias, 1e-5, &client).unwrap();
let data: Vec<f32> = result.tensor().to_vec();
// Layer norm of (x + residual) with unit weight / zero bias must be
// zero-mean regardless of the inputs.
let sum: f32 = data.iter().sum();
assert!(
sum.abs() < 1e-4,
"output should have ~0 mean, got sum={sum}"
);
}
#[test]
fn test_var_fused_add_layer_norm_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device),
true,
);
let residual = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.2, 0.3], &[1, 3], &device),
true,
);
let weight = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 1.0, 1.0], &[3], &device),
true,
);
let bias = Var::new(
Tensor::<CpuRuntime>::from_slice(&[0.0f32, 0.0, 0.0], &[3], &device),
true,
);
let output =
var_fused_add_layer_norm(&x, &residual, &weight, &bias, 1e-5, &client).unwrap();
let loss = crate::autograd::var_sum(&output, &[0, 1], false, &client).unwrap();
let grads = backward(&loss, &client).unwrap();
let gx: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
let gr: Vec<f32> = grads.get(residual.id()).unwrap().to_vec();
let gw: Vec<f32> = grads.get(weight.id()).unwrap().to_vec();
let gb: Vec<f32> = grads.get(bias.id()).unwrap().to_vec();
// x and residual only appear as a sum, so their gradients coincide.
for (a, b) in gx.iter().zip(gr.iter()) {
assert!((a - b).abs() < 1e-5, "x and residual grads must match");
}
// Bias contributes additively once per output element: grad is 1.
for val in &gb {
assert!(
(*val - 1.0).abs() < 1e-5,
"bias gradient should be 1.0, got {val}"
);
}
for val in gx.iter().chain(gw.iter()) {
assert!(val.is_finite());
}
}
}