use crate::error::{Error, Result};
use numr::autograd::{Var, var_group_norm};
use numr::ops::{NormalizationOps, ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Group normalization layer that owns its affine parameters.
///
/// The stored values are handed unchanged to `numr::autograd::var_group_norm`
/// each time `forward` is called; this type adds no computation of its own.
pub struct GroupNorm<R: Runtime> {
/// Autograd variable passed as the `weight` argument of `var_group_norm`
/// (presumably the per-channel scale; confirm against numr).
weight: Var<R>,
/// Autograd variable passed as the `bias` argument of `var_group_norm`
/// (presumably the per-channel shift; confirm against numr).
bias: Var<R>,
/// Number of groups forwarded to `var_group_norm`. Not validated here.
num_groups: usize,
/// Epsilon forwarded to `var_group_norm` (presumably the variance
/// stabiliser; confirm against numr).
eps: f32,
}
impl<R: Runtime> GroupNorm<R> {
pub fn new(
weight: Tensor<R>,
bias: Tensor<R>,
num_groups: usize,
eps: f32,
trainable: bool,
) -> Self {
Self {
weight: Var::new(weight, trainable),
bias: Var::new(bias, trainable),
num_groups,
eps,
}
}
pub fn forward<C>(&self, client: &C, input: &Var<R>) -> Result<Var<R>>
where
R: Runtime,
C: RuntimeClient<R> + NormalizationOps<R>,
R::Client: TensorOps<R> + ScalarOps<R>,
{
var_group_norm(
input,
&self.weight,
&self.bias,
self.num_groups,
self.eps,
client,
)
.map_err(Error::Numr)
}
pub fn weight(&self) -> &Var<R> {
&self.weight
}
pub fn bias(&self) -> &Var<R> {
&self.bias
}
pub fn num_groups(&self) -> usize {
self.num_groups
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::autograd::backward;
    use numr::runtime::cpu::CpuRuntime;

    /// The output tensor must have the same shape as the input tensor.
    #[test]
    fn test_groupnorm_output_shape() {
        let (client, device) = cpu_setup();
        let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 4], &[4], &device);
        let bias = Tensor::<CpuRuntime>::from_slice(&[0.0f32; 4], &[4], &device);
        let norm = GroupNorm::new(weight, bias, 2, 1e-5, false);

        let input = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[1.0f32; 24], &[2, 4, 3], &device),
            false,
        );

        let out = norm.forward(&client, &input).unwrap();
        assert_eq!(out.tensor().shape(), &[2, 4, 3]);
    }

    /// With weight 1 and bias 0, the output values of *every* group must sum
    /// to approximately zero, not just those of the first group.
    #[test]
    fn test_groupnorm_zero_mean_per_group() {
        let (client, device) = cpu_setup();
        let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 4], &[4], &device);
        let bias = Tensor::<CpuRuntime>::from_slice(&[0.0f32; 4], &[4], &device);
        let norm = GroupNorm::new(weight, bias, 2, 1e-5, false);

        let input = Var::new(
            Tensor::<CpuRuntime>::from_slice(
                &[
                    1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
                ],
                &[1, 4, 3],
                &device,
            ),
            false,
        );

        let out = norm.forward(&client, &input).unwrap();
        let data: Vec<f32> = out.tensor().to_vec();

        // 4 channels split into 2 groups gives 2 channels per group, each with
        // 3 positions, so one group covers 6 consecutive values. This assumes
        // `to_vec` yields elements in row-major order (confirm against numr).
        let group_len = 2 * 3;
        // Guard the length so `chunks_exact` cannot silently skip a group.
        assert_eq!(data.len(), 2 * group_len);

        for (g, group) in data.chunks_exact(group_len).enumerate() {
            let sum: f32 = group.iter().sum();
            assert!(sum.abs() < 1e-4, "group {g} mean should be ~0, sum={sum}");
        }
    }

    /// Backpropagating a summed output must yield a finite input gradient with
    /// one entry per input element.
    #[test]
    fn test_groupnorm_backward() {
        let (client, device) = cpu_setup();
        let weight = Tensor::<CpuRuntime>::from_slice(&[1.0f32; 4], &[4], &device);
        let bias = Tensor::<CpuRuntime>::from_slice(&[0.0f32; 4], &[4], &device);
        // Parameters are created as trainable for the backward pass.
        let norm = GroupNorm::new(weight, bias, 2, 1e-5, true);

        let input = Var::new(
            Tensor::<CpuRuntime>::from_slice(
                &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
                &[1, 4, 2],
                &device,
            ),
            true,
        );

        let out = norm.forward(&client, &input).unwrap();
        // Reduce the output with `var_sum` to obtain the value to differentiate.
        let loss = numr::autograd::var_sum(&out, &[], false, &client).unwrap();
        let grads = backward(&loss, &client).unwrap();

        let d_input: Vec<f32> = grads.get(input.id()).unwrap().to_vec();
        assert_eq!(d_input.len(), 8);
        for v in &d_input {
            assert!(v.is_finite());
        }
    }
}