use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, DivisibilityConstraintPayload, InvariantViolationPayload,
LengthMismatchPayload, OutOfRangePayload, RankMismatchPayload, Result, check,
},
ops,
stream::default_stream,
};
use smol_str::format_smolstr;
/// Wraps a freshly created handle from `mlx_array_new` in an [`Array`].
///
/// Within this file the result stands in for a weight or bias the caller did
/// not supply to the layer-norm kernel.
#[inline]
fn null_array() -> Array {
    // SAFETY (assumed): `mlx_array_new` takes no arguments, so no caller-side
    // precondition is visible here — confirm against the `mlxrs_sys` docs.
    let handle = unsafe { mlxrs_sys::mlx_array_new() };
    Array(handle)
}
/// Calls `mlx_fast_layer_norm` on `x` using the default stream.
///
/// `weight` and `bias` are optional. Whenever one is `None`, a handle obtained
/// from [`null_array`] is passed in its place.
///
/// # Errors
///
/// Propagates the error `check` produces for the status returned by the call.
fn fast_layer_norm(
    x: &Array,
    weight: Option<&Array>,
    bias: Option<&Array>,
    eps: f32,
) -> Result<Array> {
    // Both stand-ins are created up front so they stay alive across the call,
    // whether or not they end up being used.
    let absent_weight = null_array();
    let absent_bias = null_array();
    let weight_arg = match weight {
        Some(given) => given,
        None => &absent_weight,
    };
    let bias_arg = match bias {
        Some(given) => given,
        None => &absent_bias,
    };

    let mut result = null_array();
    // SAFETY (assumed — confirm against `mlxrs_sys`): every handle passed is
    // owned by an `Array` that is alive for the whole call.
    let status = unsafe {
        mlxrs_sys::mlx_fast_layer_norm(
            &mut result.0,
            x.0,
            weight_arg.0,
            bias_arg.0,
            eps,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// RMS normalization layer whose forward pass calls `mlx_fast_rms_norm`.
#[derive(Debug)]
pub struct RMSNorm {
/// Weight array handed to the kernel; borrowed through `weight_ref`.
weight: Array,
/// Epsilon forwarded unchanged to the kernel on every forward pass.
pub eps: f32,
}
impl RMSNorm {
pub fn new(weight: Array, eps: f32) -> Self {
Self { weight, eps }
}
#[inline(always)]
pub fn weight_ref(&self) -> &Array {
&self.weight
}
pub fn forward(&self, x: &Array) -> Result<Array> {
let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
check(unsafe {
mlxrs_sys::mlx_fast_rms_norm(&mut out.0, x.0, self.weight.0, self.eps, default_stream())
})?;
Ok(out)
}
}
/// Layer normalization with optional weight and bias; the forward pass
/// delegates to `fast_layer_norm`.
#[derive(Debug)]
pub struct LayerNorm {
/// Optional weight array; `None` means no weight is supplied to the kernel.
weight: Option<Array>,
/// Optional bias array; `None` means no bias is supplied to the kernel.
bias: Option<Array>,
/// Epsilon forwarded unchanged to the kernel on every forward pass.
pub eps: f32,
}
impl LayerNorm {
pub fn new(weight: Option<Array>, bias: Option<Array>, eps: f32) -> Self {
Self { weight, bias, eps }
}
#[inline(always)]
pub fn weight_ref(&self) -> Option<&Array> {
self.weight.as_ref()
}
#[inline(always)]
pub fn bias_ref(&self) -> Option<&Array> {
self.bias.as_ref()
}
pub fn forward(&self, x: &Array) -> Result<Array> {
fast_layer_norm(x, self.weight.as_ref(), self.bias.as_ref(), self.eps)
}
}
/// Group normalization layer over the last axis of its input.
///
/// Instances are only built through constructors that run
/// `validate_group_params`, so `num_groups` and `dims` are positive and
/// `dims` is a multiple of `num_groups`.
#[derive(Debug)]
pub struct GroupNorm {
/// Number of groups used when normalizing.
num_groups: i32,
/// Expected size of the input's last axis.
dims: i32,
/// Optional `(weight, bias)` pair applied after normalization as
/// `weight * normalized + bias`.
affine: Option<(Array, Array)>,
/// Epsilon used by both normalization paths.
pub eps: f32,
/// Selects the `pytorch_group_norm` path in `forward` when `true`, and the
/// `group_norm` path otherwise.
pub pytorch_compatible: bool,
}
fn validate_group_params(num_groups: i32, dims: i32) -> Result<()> {
if num_groups <= 0 {
return Err(crate::error::Error::OutOfRange(OutOfRangePayload::new(
"GroupNorm: num_groups",
"must be positive (> 0)",
format_smolstr!("{num_groups}"),
)));
}
if dims <= 0 {
return Err(crate::error::Error::OutOfRange(OutOfRangePayload::new(
"GroupNorm: dims",
"must be positive (> 0)",
format_smolstr!("{dims}"),
)));
}
if dims % num_groups != 0 {
return Err(crate::error::Error::DivisibilityConstraint(
DivisibilityConstraintPayload::new(
"GroupNorm",
"dims",
dims as u64,
"num_groups",
num_groups as u64,
),
));
}
Ok(())
}
impl GroupNorm {
    /// Builds a layer, optionally allocating default affine parameters.
    ///
    /// When `affine` is `true`, the weight is an `f32` array of ones and the
    /// bias an `f32` array of zeros, both of shape `[dims]`. Otherwise the
    /// layer carries no affine parameters.
    ///
    /// # Errors
    ///
    /// Fails when `num_groups` or `dims` is not positive, when `dims` is not a
    /// multiple of `num_groups`, or when creating the parameter arrays fails.
    pub fn new(
        num_groups: i32,
        dims: i32,
        eps: f32,
        affine: bool,
        pytorch_compatible: bool,
    ) -> Result<Self> {
        // Validate first so the `usize` conversion below cannot fail.
        validate_group_params(num_groups, dims)?;
        let affine = if affine {
            let d = usize::try_from(dims).expect("dims > 0 guarded by validate_group_params");
            Some((Array::ones::<f32>(&(d,))?, Array::zeros::<f32>(&(d,))?))
        } else {
            None
        };
        Self::with_affine(num_groups, dims, eps, affine, pytorch_compatible)
    }

    /// Builds a layer from caller-supplied affine parameters.
    ///
    /// `affine`, when present, is a `(weight, bias)` pair; each array must be
    /// rank-1 with exactly `dims` elements.
    ///
    /// # Errors
    ///
    /// * The errors of `validate_group_params` for bad `num_groups` / `dims`.
    /// * `RankMismatch` when the weight or bias is not rank-1.
    /// * `LengthMismatch` when the weight or bias length differs from `dims`.
    ///
    /// The weight is checked before the bias.
    pub fn with_affine(
        num_groups: i32,
        dims: i32,
        eps: f32,
        affine: Option<(Array, Array)>,
        pytorch_compatible: bool,
    ) -> Result<Self> {
        validate_group_params(num_groups, dims)?;
        if let Some((weight, bias)) = &affine {
            // Checked conversion, matching `new`, instead of a bare `as` cast.
            let dims_usize =
                usize::try_from(dims).expect("dims > 0 guarded by validate_group_params");

            let w_shape = weight.shape();
            if w_shape.len() != 1 {
                return Err(crate::error::Error::RankMismatch(RankMismatchPayload::new(
                    "GroupNorm: affine weight must be rank-1 [dims]",
                    w_shape.len() as u32,
                    w_shape.to_vec(),
                )));
            }
            if w_shape[0] != dims_usize {
                return Err(crate::error::Error::LengthMismatch(
                    LengthMismatchPayload::new(
                        "GroupNorm: affine weight length must equal dims",
                        dims_usize,
                        w_shape[0],
                    ),
                ));
            }

            let b_shape = bias.shape();
            if b_shape.len() != 1 {
                return Err(crate::error::Error::RankMismatch(RankMismatchPayload::new(
                    "GroupNorm: affine bias must be rank-1 [dims]",
                    b_shape.len() as u32,
                    b_shape.to_vec(),
                )));
            }
            if b_shape[0] != dims_usize {
                return Err(crate::error::Error::LengthMismatch(
                    LengthMismatchPayload::new(
                        "GroupNorm: affine bias length must equal dims",
                        dims_usize,
                        b_shape[0],
                    ),
                ));
            }
        }
        Ok(Self {
            num_groups,
            dims,
            affine,
            eps,
            pytorch_compatible,
        })
    }

    /// Returns the configured number of groups.
    pub fn num_groups(&self) -> i32 {
        self.num_groups
    }

    /// Returns the configured size of the input's last axis.
    pub fn dims(&self) -> i32 {
        self.dims
    }

    /// Borrows the `(weight, bias)` pair, if the layer has affine parameters.
    pub fn affine(&self) -> Option<(&Array, &Array)> {
        self.affine.as_ref().map(|(w, b)| (w, b))
    }

    /// Normalizes `x` and then applies the affine parameters, if any.
    ///
    /// The normalization path is chosen by `pytorch_compatible`. With affine
    /// parameters present the result is `weight * normalized + bias`.
    ///
    /// # Errors
    ///
    /// Propagates shape-validation errors and any error from the array ops.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let normalized = if self.pytorch_compatible {
            self.pytorch_group_norm(x)?
        } else {
            self.group_norm(x)?
        };
        match &self.affine {
            Some((w, b)) => {
                let scaled = ops::arithmetic::multiply(w, &normalized)?;
                ops::arithmetic::add(&scaled, b)
            }
            None => Ok(normalized),
        }
    }

    /// Checks that an input shape is usable by this layer and returns the size
    /// of its last axis as `i32`.
    ///
    /// # Errors
    ///
    /// * `RankMismatch` when the shape has fewer than two axes.
    /// * `ArithmeticOverflow` when the last axis does not fit in `i32`.
    /// * `LengthMismatch` when the last axis differs from the configured `dims`.
    /// * `DivisibilityConstraint` when the last axis is not a multiple of
    ///   `num_groups`.
    fn validate_input_shape(&self, orig_shape: &[usize]) -> Result<i32> {
        // The slice pattern matches exactly the shapes with rank >= 2 and
        // yields the last axis without needing a fallible `last()` lookup.
        let &[_, .., dims] = orig_shape else {
            return Err(crate::error::Error::RankMismatch(RankMismatchPayload::new(
                "GroupNorm input must have rank >= 2 (at least [batch, dims])",
                orig_shape.len() as u32,
                orig_shape.to_vec(),
            )));
        };
        let dims_i32 = i32::try_from(dims).map_err(|_| {
            crate::error::Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "GroupNorm: feature dim exceeds i32::MAX",
                "i32",
                [("dim", dims as u64)],
            ))
        })?;
        if dims_i32 != self.dims {
            return Err(crate::error::Error::LengthMismatch(
                LengthMismatchPayload::new(
                    "GroupNorm: input last-axis must match configured dims",
                    self.dims as usize,
                    dims_i32 as usize,
                ),
            ));
        }
        if dims_i32 % self.num_groups != 0 {
            return Err(crate::error::Error::DivisibilityConstraint(
                DivisibilityConstraintPayload::new(
                    "GroupNorm",
                    "feature_dim",
                    dims_i32 as u64,
                    "num_groups",
                    self.num_groups as u64,
                ),
            ));
        }
        Ok(dims_i32)
    }

    /// Default (non-PyTorch-compatible) normalization path.
    ///
    /// Reshapes `x` to `[batch, inferred, num_groups]`, computes mean and
    /// variance over axis 1 (keeping that axis), evaluates
    /// `(x - mean) * rsqrt(var + eps)`, and reshapes back to the input shape.
    fn group_norm(&self, x: &Array) -> Result<Array> {
        let orig_shape = x.shape();
        self.validate_input_shape(&orig_shape)?;
        let batch = batch_dim(&orig_shape)?;
        let inferred = inferred_dim(&orig_shape, &[batch, self.num_groups])?;
        let three_d: &[i32] = &[batch, inferred, self.num_groups];
        let reshaped = ops::shape::reshape(x, &three_d)?;

        let means = ops::reduction::mean_axes(&reshaped, &[1], true)?;
        // NOTE(review): the trailing `0` is presumably the ddof argument — confirm.
        let var = ops::reduction::var_axes(&reshaped, &[1], true, 0)?;
        let eps_like = scalar_like(self.eps, &var)?;
        let denom = ops::arithmetic::rsqrt(&ops::arithmetic::add(&var, &eps_like)?)?;
        let centered = ops::arithmetic::subtract(&reshaped, &means)?;
        let normalized = ops::arithmetic::multiply(&centered, &denom)?;

        let orig_i32 = shape_to_i32(&orig_shape)?;
        let orig_slice: &[i32] = &orig_i32;
        ops::shape::reshape(&normalized, &orig_slice)
    }

    /// PyTorch-compatible normalization path.
    ///
    /// Reshapes `x` to `[batch, mid, num_groups, group_size]`, swaps axes 1 and
    /// 2, collapses to `[batch, num_groups, mid * group_size]`, and runs
    /// `fast_layer_norm` without weight or bias. The reshape and transpose are
    /// then undone to restore the input shape.
    fn pytorch_group_norm(&self, x: &Array) -> Result<Array> {
        let orig_shape = x.shape();
        let dims_i32 = self.validate_input_shape(&orig_shape)?;
        let batch = batch_dim(&orig_shape)?;
        let group_size = dims_i32 / self.num_groups;
        let mid = inferred_dim(&orig_shape, &[batch, self.num_groups, group_size])?;

        let four_d: &[i32] = &[batch, mid, self.num_groups, group_size];
        let x = ops::shape::reshape(x, &four_d)?;
        let x = ops::shape::transpose_axes(&x, &[0, 2, 1, 3])?;

        let collapsed = mid.checked_mul(group_size).ok_or_else(|| {
            crate::error::Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "GroupNorm: mid * group_size",
                "i32",
                [("mid", mid as u64), ("group_size", group_size as u64)],
            ))
        })?;
        let three_d: &[i32] = &[batch, self.num_groups, collapsed];
        let x = ops::shape::reshape(&x, &three_d)?;
        let x = fast_layer_norm(&x, None, None, self.eps)?;

        let four_d_back: &[i32] = &[batch, self.num_groups, mid, group_size];
        let x = ops::shape::reshape(&x, &four_d_back)?;
        let x = ops::shape::transpose_axes(&x, &[0, 2, 1, 3])?;

        let orig_i32 = shape_to_i32(&orig_shape)?;
        let orig_slice: &[i32] = &orig_i32;
        ops::shape::reshape(&x, &orig_slice)
    }
}
fn shape_to_i32(shape: &[usize]) -> Result<Vec<i32>> {
shape
.iter()
.map(|&d| {
i32::try_from(d).map_err(|_| {
crate::error::Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"shape_to_i32: dim exceeds i32::MAX",
"i32",
[("dim", d as u64)],
))
})
})
.collect()
}
fn batch_dim(shape: &[usize]) -> Result<i32> {
let b = *shape.first().ok_or_else(|| {
crate::error::Error::RankMismatch(RankMismatchPayload::new(
"GroupNorm input must have rank >= 1 (the batch axis)",
0,
shape.to_vec(),
))
})?;
i32::try_from(b).map_err(|_| {
crate::error::Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"GroupNorm: batch dim exceeds i32::MAX",
"i32",
[("batch_dim", b as u64)],
))
})
}
fn inferred_dim(shape: &[usize], known_dims: &[i32]) -> Result<i32> {
let total: usize = shape
.iter()
.enumerate()
.try_fold(1usize, |acc, (idx, &dim)| {
acc.checked_mul(dim).ok_or_else(|| {
crate::error::Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"GroupNorm: shape product overflows usize",
"usize",
[
("acc", acc as u64),
("dim", dim as u64),
("dim_index", idx as u64),
],
))
})
})?;
let mut divisor: usize = 1;
for &d in known_dims {
let du = usize::try_from(d).map_err(|_| {
crate::error::Error::OutOfRange(OutOfRangePayload::new(
"GroupNorm: known reshape dim",
"must be non-negative",
format_smolstr!("{d}"),
))
})?;
divisor = divisor.checked_mul(du).ok_or_else(|| {
crate::error::Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"GroupNorm: reshape divisor product",
"usize",
[("divisor", divisor as u64), ("factor", du as u64)],
))
})?;
}
if divisor == 0 {
return Err(crate::error::Error::InvariantViolation(
InvariantViolationPayload::new(
"GroupNorm: inferred_dim reshape divisor",
"must be non-zero (one of the known_dims was 0)",
),
));
}
if !total.is_multiple_of(divisor) {
return Err(crate::error::Error::DivisibilityConstraint(
DivisibilityConstraintPayload::new(
"GroupNorm: cannot reshape elements into a layout",
"total_elements",
total as u64,
"divisor_per_slot",
divisor as u64,
),
));
}
let inferred = total / divisor;
i32::try_from(inferred).map_err(|_| {
crate::error::Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"GroupNorm: inferred dim exceeds i32::MAX",
"i32",
[("inferred_dim", inferred as u64)],
))
})
}
/// Builds a one-element array holding `value`, cast to the dtype of `like`.
///
/// `ensure_handler_installed` is called before any array is created.
///
/// # Errors
///
/// Propagates failures from creating the array, reading `like`'s dtype, or
/// the cast.
fn scalar_like(value: f32, like: &Array) -> Result<Array> {
    crate::error::ensure_handler_installed();
    // Create the scalar before querying the dtype, so the first failure
    // reported is the same one either way.
    let scalar = Array::full::<f32>(&(1,), value)?;
    let target = like.dtype()?;
    ops::misc::astype(&scalar, target)
}
#[cfg(test)]
mod tests;