use crate::error::Result;
use mlx_rs::Array;
use mlx_rs::module::Param;
/// 1-D convolution whose kernel is reparameterised with weight normalisation.
///
/// The effective kernel is `weight_g * weight_v / ||weight_v||`, where the norm is taken over
/// every axis except the first, i.e. one norm per output channel (see `weight_norm`).
/// Inputs and outputs are channels-first: `[batch, channels, length]`.
#[derive(Debug)]
pub struct WeightNormConv1d {
/// Direction tensor `v`, shape `[out_channels, in_channels / groups, kernel_size]`.
pub weight_v: Param<Array>,
/// Magnitude tensor `g`, shape `[out_channels, 1, 1]` (keepdims reduction of `weight_v`).
pub weight_g: Param<Array>,
/// Optional per-output-channel bias of shape `[out_channels]`, added after the convolution.
pub bias: Option<Param<Array>>,
/// Number of input channels expected on axis 1 of the input.
pub in_channels: i32,
/// Number of output channels produced on axis 1 of the output.
pub out_channels: i32,
/// Kernel width along the time axis.
pub kernel_size: i32,
/// Convolution stride along the time axis.
pub stride: i32,
/// Symmetric zero padding applied to both ends of the time axis.
pub padding: i32,
/// Kernel dilation along the time axis.
pub dilation: i32,
/// Number of channel groups passed to `conv1d`.
pub groups: i32,
}
impl WeightNormConv1d {
    /// Creates a weight-normalised 1-D convolution.
    ///
    /// Unspecified options fall back to `stride = 1`, `padding = 0`, `dilation = 1`,
    /// `groups = 1` and `bias = true`.
    ///
    /// Initialisation:
    /// - `weight_v` is sampled uniformly from `[-b, b]`, where
    ///   `b = 1 / sqrt((in_channels / groups) * kernel_size)`.
    /// - `weight_g` starts as the per-output-channel norm of `weight_v`, so the effective
    ///   kernel initially equals `weight_v` (up to the epsilon used in `weight_norm`).
    /// - The bias, when enabled, starts at zero.
    ///
    /// # Errors
    /// Propagates any failure from MLX array creation.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        in_channels: i32,
        out_channels: i32,
        kernel_size: i32,
        stride: Option<i32>,
        padding: Option<i32>,
        dilation: Option<i32>,
        groups: Option<i32>,
        bias: Option<bool>,
    ) -> Result<Self> {
        let groups = groups.unwrap_or(1);
        let channels_per_group = in_channels / groups;
        let limit = ((channels_per_group * kernel_size) as f32).recip().sqrt();

        let direction = mlx_rs::random::uniform::<_, f32>(
            -limit,
            limit,
            &[out_channels, channels_per_group, kernel_size],
            None,
        )?;
        let magnitude = weight_norm(&direction)?;

        let bias = bias
            .unwrap_or(true)
            .then(|| mlx_rs::ops::zeros::<f32>(&[out_channels]))
            .transpose()?
            .map(Param::new);

        Ok(Self {
            weight_v: Param::new(direction),
            weight_g: Param::new(magnitude),
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride: stride.unwrap_or(1),
            padding: padding.unwrap_or(0),
            dilation: dilation.unwrap_or(1),
            groups,
        })
    }

    /// Materialises the effective kernel `g * v / ||v||`, shape
    /// `[out_channels, in_channels / groups, kernel_size]`.
    fn compute_weight(&self) -> Result<Array> {
        let direction = self.weight_v.as_ref();
        let gain = self.weight_g.as_ref();
        let unit_direction = direction.divide(&weight_norm(direction)?)?;
        Ok(unit_direction.multiply(gain)?)
    }

    /// Applies the convolution to a channels-first input `[batch, in_channels, length]`.
    ///
    /// Returns a channels-first output `[batch, out_channels, new_length]`, plus the bias
    /// broadcast over batch and time when present.
    ///
    /// # Errors
    /// Propagates MLX errors (for example, mismatched channel counts).
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let weight = self.compute_weight()?;

        // MLX convolutions are channels-last: input [N, L, C], kernel [O, K, I].
        let input_nlc = x.transpose_axes(&[0, 2, 1])?;
        let kernel_oki = weight.transpose_axes(&[0, 2, 1])?;
        let output_ncl = mlx_rs::ops::conv1d(
            &input_nlc,
            &kernel_oki,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )?
        .transpose_axes(&[0, 2, 1])?;

        match &self.bias {
            Some(bias) => {
                let bias_1c1 = bias.as_ref().reshape(&[1, self.out_channels, 1])?;
                Ok(output_ncl.add(&bias_1c1)?)
            }
            None => Ok(output_ncl),
        }
    }
}
/// Transposed 1-D convolution whose kernel is reparameterised with weight normalisation.
///
/// The effective kernel is `weight_g * weight_v / ||weight_v||`, with one norm per index of the
/// first kernel axis (see `weight_norm`). Inputs and outputs are channels-first:
/// `[batch, channels, length]`.
#[derive(Debug)]
pub struct WeightNormConvTranspose1d {
/// Direction tensor `v`, shape `[in_channels, out_channels / groups, kernel_size]`.
pub weight_v: Param<Array>,
/// Magnitude tensor `g`, shape `[in_channels, 1, 1]` (keepdims reduction of `weight_v`).
pub weight_g: Param<Array>,
/// Optional per-output-channel bias of shape `[out_channels]`, added after the convolution.
pub bias: Option<Param<Array>>,
/// Number of input channels expected on axis 1 of the input.
pub in_channels: i32,
/// Number of output channels produced on axis 1 of the output.
pub out_channels: i32,
/// Kernel width along the time axis.
pub kernel_size: i32,
/// Upsampling stride along the time axis.
pub stride: i32,
/// Transposed-convolution padding (samples cropped from the output).
pub padding: i32,
/// Extra samples added to one side of the output length.
pub output_padding: i32,
/// Kernel dilation along the time axis.
pub dilation: i32,
/// Number of channel groups.
pub groups: i32,
}
impl WeightNormConvTranspose1d {
    /// Creates a weight-normalised transposed 1-D convolution.
    ///
    /// Unspecified options fall back to `stride = 1`, `padding = 0`, `output_padding = 0`,
    /// `dilation = 1`, `groups = 1` and `bias = true`.
    ///
    /// Initialisation:
    /// - `weight_v` (layout `[in_channels, out_channels / groups, kernel_size]`) is sampled
    ///   uniformly from `[-b, b]`, where `b = 1 / sqrt(in_channels * kernel_size)`.
    /// - `weight_g` starts as the per-row norm of `weight_v`, so the effective kernel initially
    ///   equals `weight_v` (up to the epsilon used in `weight_norm`).
    /// - The bias, when enabled, starts at zero.
    ///
    /// # Errors
    /// Propagates any failure from MLX array creation.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        in_channels: i32,
        out_channels: i32,
        kernel_size: i32,
        stride: Option<i32>,
        padding: Option<i32>,
        output_padding: Option<i32>,
        dilation: Option<i32>,
        groups: Option<i32>,
        bias: Option<bool>,
    ) -> Result<Self> {
        let stride = stride.unwrap_or(1);
        let padding = padding.unwrap_or(0);
        let output_padding = output_padding.unwrap_or(0);
        let dilation = dilation.unwrap_or(1);
        let groups = groups.unwrap_or(1);
        let use_bias = bias.unwrap_or(true);

        let fan_in = in_channels * kernel_size;
        let bound = (1.0 / fan_in as f32).sqrt();
        let weight_v = mlx_rs::random::uniform::<_, f32>(
            -bound,
            bound,
            &[in_channels, out_channels / groups, kernel_size],
            None,
        )?;
        let weight_g = weight_norm(&weight_v)?;

        let bias = if use_bias {
            Some(Param::new(mlx_rs::ops::zeros::<f32>(&[out_channels])?))
        } else {
            None
        };

        Ok(Self {
            weight_v: Param::new(weight_v),
            weight_g: Param::new(weight_g),
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            dilation,
            groups,
        })
    }

    /// Materialises the effective kernel `g * v / ||v||` in the stored layout
    /// `[in_channels, out_channels / groups, kernel_size]`.
    fn compute_weight(&self) -> Result<Array> {
        let v = self.weight_v.as_ref();
        let g = self.weight_g.as_ref();
        let norm = weight_norm(v)?;
        let v_normalized = v.divide(&norm)?;
        Ok(v_normalized.multiply(g)?)
    }

    /// Applies the transposed convolution to a channels-first input
    /// `[batch, in_channels, length]`.
    ///
    /// The computation is delegated to `conv_transpose_1d_manual`. It is intended to follow the
    /// standard transposed-convolution output length
    /// `(length - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`.
    /// The bias, when present, is broadcast over batch and time.
    ///
    /// # Errors
    /// Propagates MLX errors (for example, mismatched channel counts).
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let weight = self.compute_weight()?;

        // The stored kernel uses the transposed-conv layout [in, out / groups, k]. The
        // conv1d-based implementation needs an ordinary kernel [out, in / groups, k], so the
        // in/out axes are swapped *within each group*. For groups == 1 this reduces to a plain
        // [1, 0, 2] transpose. A global transpose would give [out / groups, in, k], which does
        // not match what a grouped conv1d expects when groups > 1.
        let groups = self.groups;
        let in_per_group = weight.dim(0) / groups;
        let out_per_group = weight.dim(1);
        let kernel_size = weight.dim(2);
        let conv_weight = weight
            .reshape(&[groups, in_per_group, out_per_group, kernel_size])?
            .transpose_axes(&[0, 2, 1, 3])?
            .reshape(&[groups * out_per_group, in_per_group, kernel_size])?;

        let output = conv_transpose_1d_manual(
            x,
            &conv_weight,
            self.stride,
            self.padding,
            self.output_padding,
            self.dilation,
            groups,
        )?;

        if let Some(bias) = &self.bias {
            let bias_reshaped = bias.as_ref().reshape(&[1, self.out_channels, 1])?;
            Ok(output.add(&bias_reshaped)?)
        } else {
            Ok(output)
        }
    }
}
/// Per-row L2 norm of a kernel: the weight-normalisation denominator.
///
/// The squared values are reduced over axes 1 and 2 with `keepdims`, so a `[rows, a, b]`
/// kernel yields `[rows, 1, 1]`, which broadcasts back against the kernel. A constant `1e-12`
/// is added so that dividing by the result never divides by exactly zero.
fn weight_norm(weight: &Array) -> Result<Array> {
    let epsilon = Array::from_f32(1e-12);
    let row_norms = weight
        .multiply(weight)?
        .sum_axes(&[1, 2], Some(true))?
        .sqrt()?;
    Ok(row_norms.add(&epsilon)?)
}
/// Returns `arr` with the elements along `axis` in reverse order.
///
/// This is implemented as a gather (`take_axis`) with indices `len - 1, len - 2, ..., 0`.
fn flip_axis(arr: &Array, axis: i32) -> Result<Array> {
    let len = arr.dim(axis);
    let reversed: Vec<i32> = (0..len).map(|i| len - 1 - i).collect();
    let gather_index = Array::from_slice(&reversed, &[len]);
    Ok(arr.take_axis(&gather_index, axis)?)
}
/// Computes a 1-D transposed convolution using only an ordinary `conv1d`.
///
/// A transposed convolution equals a stride-1 convolution over a zero-stuffed input with a
/// kernel reversed along time. The steps are:
///
/// 1. Insert `stride - 1` zeros between consecutive input samples. The length becomes
///    `(in_length - 1) * stride + 1`.
/// 2. Run `conv1d` with the flipped kernel and symmetric padding `dilation * (kernel_size - 1)`.
///    This gives the *uncropped* result of length
///    `(in_length - 1) * stride + dilation * (kernel_size - 1) + 1`.
/// 3. Crop `padding` samples from the left and `padding - output_padding` from the right.
///    If `output_padding > padding`, the missing tail positions receive no contribution from
///    any input sample, so they are filled with zeros.
///
/// The final length is therefore
/// `(in_length - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`.
///
/// # Arguments
/// * `x` - channels-first input `[batch, in_channels, in_length]`.
/// * `weight` - ordinary (non-transposed) conv kernel `[out_channels, in_channels / groups,
///   kernel_size]`, i.e. already rearranged from the transposed-conv layout by the caller.
/// * `stride`, `padding`, `output_padding`, `dilation`, `groups` - transposed-conv parameters.
///
/// # Returns
/// Channels-first output `[batch, out_channels, out_length]`.
///
/// NOTE(review): the inserted zeros are created as `f32`. For other input dtypes MLX type
/// promotion applies — TODO confirm this is intended.
fn conv_transpose_1d_manual(
    x: &Array,
    weight: &Array,
    stride: i32,
    padding: i32,
    output_padding: i32,
    dilation: i32,
    groups: i32,
) -> Result<Array> {
    use mlx_rs::ops::indexing::IndexOp;

    let kernel_size = weight.dim(2);

    // Channels-first wrapper around MLX's channels-last conv1d (stride 1).
    let conv1d_ncl = |input: &Array, w: &Array, pad: i32| -> Result<Array> {
        let input_nlc = input.transpose_axes(&[0, 2, 1])?;
        let weight_oki = w.transpose_axes(&[0, 2, 1])?;
        let output_nlc = mlx_rs::ops::conv1d(&input_nlc, &weight_oki, 1, pad, dilation, groups)?;
        output_nlc.transpose_axes(&[0, 2, 1]).map_err(Into::into)
    };

    // Step 1: zero-stuff the input so a stride-1 convolution reproduces the upsampling stride.
    let stuffed;
    let input = if stride > 1 {
        let batch = x.dim(0);
        let channels = x.dim(1);
        let in_length = x.dim(2);
        let gaps = mlx_rs::ops::zeros::<f32>(&[batch, channels, in_length, stride - 1])?;
        let expanded = x.reshape(&[batch, channels, in_length, 1])?;
        let interleaved = mlx_rs::ops::concatenate_axis(&[&expanded, &gaps], -1)?
            .reshape(&[batch, channels, in_length * stride])?;
        // Drop the `stride - 1` zeros that trail the last real sample.
        let stuffed_length = (in_length - 1) * stride + 1;
        stuffed = interleaved.index((.., .., ..stuffed_length));
        &stuffed
    } else {
        x
    };

    // Step 2: uncropped transposed convolution.
    let full_padding = dilation * (kernel_size - 1);
    let weight_flipped = flip_axis(weight, 2)?;
    let full = conv1d_ncl(input, &weight_flipped, full_padding)?;

    // Step 3: crop `padding` on the left and `padding - output_padding` on the right. This keeps
    // real (non-zero) values in the output_padding region whenever padding > 0.
    let full_length = full.dim(2);
    let end = full_length - padding + output_padding;
    let cropped = full.index((.., .., padding..end.min(full_length)));

    // Positions past the uncropped length have no contributing input sample: they are zero.
    let missing = end - full_length;
    if missing > 0 {
        let tail = mlx_rs::ops::zeros::<f32>(&[full.dim(0), full.dim(1), missing])?;
        return mlx_rs::ops::concatenate_axis(&[&cropped, &tail], -1).map_err(Into::into);
    }
    Ok(cropped)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Forces evaluation of `y` and checks that it has exactly the `expected` shape.
    fn assert_evaluated_shape(y: &Array, expected: &[i32]) {
        y.eval().unwrap();
        assert_eq!(y.shape(), expected);
    }

    #[test]
    fn test_weight_norm_conv1d_shape() {
        let conv =
            WeightNormConv1d::new(4, 8, 3, Some(1), Some(1), None, None, Some(true)).unwrap();
        let x = mlx_rs::random::normal::<f32>(&[2, 4, 16], None, None, None).unwrap();
        let y = conv.forward(&x).unwrap();
        assert_evaluated_shape(&y, &[2, 8, 16]);
    }

    #[test]
    fn test_weight_norm_conv1d_no_bias() {
        let conv =
            WeightNormConv1d::new(4, 8, 3, Some(1), Some(1), None, None, Some(false)).unwrap();
        let x = mlx_rs::random::normal::<f32>(&[2, 4, 16], None, None, None).unwrap();
        let y = conv.forward(&x).unwrap();
        assert_evaluated_shape(&y, &[2, 8, 16]);
        assert!(conv.bias.is_none());
    }

    #[test]
    fn test_weight_norm_values() {
        let conv = WeightNormConv1d::new(2, 4, 3, None, None, None, None, None).unwrap();
        let weight = conv.compute_weight().unwrap();
        assert_evaluated_shape(&weight, &[4, 2, 3]);
    }

    #[test]
    fn test_conv_transpose1d_shape() {
        // (16 - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 32
        let conv =
            WeightNormConvTranspose1d::new(8, 4, 4, Some(2), Some(1), None, None, None, Some(true))
                .unwrap();
        let x = mlx_rs::random::normal::<f32>(&[1, 8, 16], None, None, None).unwrap();
        let y = conv.forward(&x).unwrap();
        assert_evaluated_shape(&y, &[1, 4, 32]);
    }

    #[test]
    fn test_conv_transpose1d_upsample_4x() {
        // (8 - 1) * 4 - 2 * 6 + (16 - 1) + 1 = 32
        let conv = WeightNormConvTranspose1d::new(
            512,
            256,
            16,
            Some(4),
            Some(6),
            None,
            None,
            None,
            Some(true),
        )
        .unwrap();
        let x = mlx_rs::random::normal::<f32>(&[1, 512, 8], None, None, None).unwrap();
        let y = conv.forward(&x).unwrap();
        assert_evaluated_shape(&y, &[1, 256, 32]);
    }
}