use crate::error::Result;
use pmetal_bridge::compat::{Array, Param, ops, random};
/// 1-D convolution whose kernel is reparameterised with weight normalisation.
///
/// The effective kernel is `weight_g * weight_v / ||weight_v||`. The norm is taken per
/// output channel, reduced over axes 1 and 2 (see `weight_norm`). Inputs and outputs use the
/// channels-first `(batch, channels, length)` layout.
#[derive(Debug)]
pub struct WeightNormConv1d {
    /// Direction parameter `v`, shape `(out_channels, in_channels / groups, kernel_size)`.
    pub weight_v: Param<Array>,
    /// Magnitude parameter `g`: one value per output channel, as produced by `weight_norm`.
    pub weight_g: Param<Array>,
    /// Optional per-output-channel bias of shape `(out_channels,)`.
    pub bias: Option<Param<Array>>,
    pub in_channels: i32,
    pub out_channels: i32,
    pub kernel_size: i32,
    pub stride: i32,
    pub padding: i32,
    pub dilation: i32,
    pub groups: i32,
}

impl WeightNormConv1d {
    /// Creates a layer with a randomly initialised direction and a matching magnitude.
    ///
    /// Defaults for `None` arguments:
    /// - `stride = 1`, `padding = 0`, `dilation = 1`, `groups = 1`
    /// - `bias = true` (initialised to zeros)
    ///
    /// `weight_v` is drawn from `U(-1/sqrt(fan_in), 1/sqrt(fan_in))` with
    /// `fan_in = (in_channels / groups) * kernel_size`. `weight_g` starts as `||weight_v||`,
    /// so the effective kernel initially (up to the norm epsilon) equals `weight_v`.
    ///
    /// # Errors
    /// Propagates any error from `weight_norm`.
    #[allow(clippy::too_many_arguments)] // mirrors the conventional Conv1d constructor signature
    pub fn new(
        in_channels: i32,
        out_channels: i32,
        kernel_size: i32,
        stride: Option<i32>,
        padding: Option<i32>,
        dilation: Option<i32>,
        groups: Option<i32>,
        bias: Option<bool>,
    ) -> Result<Self> {
        let groups = groups.unwrap_or(1);
        let channels_per_group = in_channels / groups;
        let limit = ((channels_per_group * kernel_size) as f32).recip().sqrt();

        // Map U[0, 1) affinely onto [-limit, limit).
        let weight_v = random::uniform(
            &[out_channels, channels_per_group, kernel_size],
            pmetal_bridge::compat::Dtype::Float32,
        )
        .multiply(&Array::from_f32(2.0 * limit))
        .add(&Array::from_f32(-limit));
        let weight_g = weight_norm(&weight_v)?;

        // dtype code 10: presumably float32, matching the weights above — TODO confirm.
        let bias = bias
            .unwrap_or(true)
            .then(|| Param::new(Array::zeros(&[out_channels], 10)));

        Ok(Self {
            weight_v: Param::new(weight_v),
            weight_g: Param::new(weight_g),
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride: stride.unwrap_or(1),
            padding: padding.unwrap_or(0),
            dilation: dilation.unwrap_or(1),
            groups,
        })
    }

    /// Materialises the effective kernel `g * v / ||v||`.
    fn compute_weight(&self) -> Result<Array> {
        let direction = &self.weight_v.value;
        let unit_direction = direction.divide(&weight_norm(direction)?);
        Ok(unit_direction.multiply(&self.weight_g.value))
    }

    /// Applies the convolution to a channels-first `(batch, in_channels, length)` input.
    ///
    /// Returns a `(batch, out_channels, out_length)` array with the bias, if any, added.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        // `ops::conv1d` is fed channels-last operands:
        // - input `(N, L, C)`
        // - kernel `(O, K, C / groups)`
        // so we transpose on the way in and back out.
        let kernel = self.compute_weight()?.transpose_axes(&[0, 2, 1]);
        let input = x.transpose_axes(&[0, 2, 1]);
        let y = ops::conv1d(
            &input,
            &kernel,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        .transpose_axes(&[0, 2, 1]);

        Ok(match &self.bias {
            Some(b) => y.add(&b.value.reshape(&[1, self.out_channels, 1])),
            None => y,
        })
    }
}
/// Transposed 1-D convolution whose kernel is reparameterised with weight normalisation.
///
/// The kernel uses the transposed-conv layout `(in_channels, out_channels / groups,
/// kernel_size)`. The normalisation norm is taken per input channel, over axes 1 and 2.
/// Inputs and outputs use the channels-first `(batch, channels, length)` layout.
#[derive(Debug)]
pub struct WeightNormConvTranspose1d {
    /// Direction parameter `v`, shape `(in_channels, out_channels / groups, kernel_size)`.
    pub weight_v: Param<Array>,
    /// Magnitude parameter `g`: one value per input channel, as produced by `weight_norm`.
    pub weight_g: Param<Array>,
    /// Optional per-output-channel bias of shape `(out_channels,)`.
    pub bias: Option<Param<Array>>,
    pub in_channels: i32,
    pub out_channels: i32,
    pub kernel_size: i32,
    pub stride: i32,
    pub padding: i32,
    pub output_padding: i32,
    pub dilation: i32,
    pub groups: i32,
}

impl WeightNormConvTranspose1d {
    /// Creates a layer with a randomly initialised direction and a matching magnitude.
    ///
    /// Defaults for `None` arguments:
    /// - `stride = 1`, `padding = 0`, `output_padding = 0`, `dilation = 1`, `groups = 1`
    /// - `bias = true` (initialised to zeros)
    ///
    /// `weight_v` is drawn from `U(-1/sqrt(fan_in), 1/sqrt(fan_in))` with
    /// `fan_in = in_channels * kernel_size`. `weight_g` starts as `||weight_v||`.
    ///
    /// # Errors
    /// Propagates any error from `weight_norm`.
    #[allow(clippy::too_many_arguments)] // mirrors the conventional ConvTranspose1d signature
    pub fn new(
        in_channels: i32,
        out_channels: i32,
        kernel_size: i32,
        stride: Option<i32>,
        padding: Option<i32>,
        output_padding: Option<i32>,
        dilation: Option<i32>,
        groups: Option<i32>,
        bias: Option<bool>,
    ) -> Result<Self> {
        let stride = stride.unwrap_or(1);
        let padding = padding.unwrap_or(0);
        let output_padding = output_padding.unwrap_or(0);
        let dilation = dilation.unwrap_or(1);
        let groups = groups.unwrap_or(1);
        let use_bias = bias.unwrap_or(true);

        let fan_in = in_channels * kernel_size;
        let bound = (1.0 / fan_in as f32).sqrt();
        // Map U[0, 1) affinely onto [-bound, bound).
        let u = random::uniform(
            &[in_channels, out_channels / groups, kernel_size],
            pmetal_bridge::compat::Dtype::Float32,
        );
        let weight_v = u
            .multiply(&Array::from_f32(2.0 * bound))
            .add(&Array::from_f32(-bound));
        let weight_g = weight_norm(&weight_v)?;

        // dtype code 10: presumably float32, matching the weights above — TODO confirm.
        let bias = use_bias.then(|| Param::new(Array::zeros(&[out_channels], 10)));

        Ok(Self {
            weight_v: Param::new(weight_v),
            weight_g: Param::new(weight_g),
            bias,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            output_padding,
            dilation,
            groups,
        })
    }

    /// Materialises the effective kernel `g * v / ||v||` (transposed-conv layout).
    fn compute_weight(&self) -> Result<Array> {
        let v = &self.weight_v.value;
        let g = &self.weight_g.value;
        let norm = weight_norm(v)?;
        let v_normalized = v.divide(&norm);
        Ok(v_normalized.multiply(g))
    }

    /// Applies the transposed convolution to a channels-first `(batch, in_channels, L)` input.
    ///
    /// The output has shape `(batch, out_channels, L_out)` with
    /// `L_out = (L - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
    ///          + output_padding + 1`.
    /// The bias, if any, is added.
    ///
    /// # Errors
    /// Propagates errors from the weight computation and the convolution helper.
    pub fn forward(&self, x: &Array) -> Result<Array> {
        let weight = self.compute_weight()?;

        // Re-layout the kernel into the regular grouped-conv layout `(out, in/g, K)` that
        // `conv_transpose_1d_manual` expects:
        //   (in, out/g, K) -> (g, in/g, out/g, K) -> (g, out/g, in/g, K) -> (out, in/g, K)
        // A plain swap of the first two axes is only correct for groups == 1. For grouped
        // layers, input and output channels must be swapped *within* each group.
        let groups = self.groups;
        let in_per_group = weight.dim(0) / groups;
        let out_per_group = weight.dim(1);
        let kernel = weight.dim(2);
        let weight_conv = weight
            .reshape(&[groups, in_per_group, out_per_group, kernel])
            .transpose_axes(&[0, 2, 1, 3])
            .reshape(&[groups * out_per_group, in_per_group, kernel]);

        let output = conv_transpose_1d_manual(
            x,
            &weight_conv,
            self.stride,
            self.padding,
            self.output_padding,
            self.dilation,
            self.groups,
        )?;

        Ok(match &self.bias {
            Some(bias) => output.add(&bias.value.reshape(&[1, self.out_channels, 1])),
            None => output,
        })
    }
}
/// L2 norm of a 3-D weight tensor per index of axis 0, reduced over axes 1 and 2.
///
/// The `true` flag passed to `sum_axes` presumably keeps the reduced axes, so the result
/// broadcasts against `weight` — TODO confirm against the bridge API. A tiny epsilon
/// (`1e-12`) is added so that dividing by the result never divides by zero.
fn weight_norm(weight: &Array) -> Result<Array> {
    let squared_sum = weight.multiply(weight).sum_axes(&[1, 2], true);
    Ok(squared_sum.sqrt().add(&Array::from_f32(1e-12)))
}
/// Returns `arr` with the order of elements along `axis` reversed.
///
/// Implemented as a gather with the index list `[len - 1, len - 2, ..., 0]`.
fn flip_axis(arr: &Array, axis: i32) -> Result<Array> {
    let len = arr.dim(axis);
    let reversed: Vec<i32> = (0..len).map(|i| len - 1 - i).collect();
    Ok(arr.take_axis(&Array::from_i32_slice(&reversed), axis))
}
/// Computes a transposed 1-D convolution by rewriting it as an ordinary stride-1 convolution.
///
/// The rewrite follows the standard identity:
/// 1. Insert `stride - 1` zeros between consecutive input samples (zero-insertion upsampling).
/// 2. Append `output_padding` zeros on the right. The extra output samples are then *computed*
///    from the kernel, as in PyTorch's `ConvTranspose1d`, rather than filled with zeros.
/// 3. Correlate with the kernel flipped along its spatial axis, padding both sides by
///    `dilation * (kernel_size - 1) - padding`. When that amount is negative, pad by 0 and
///    crop the surplus from both ends of the result instead.
///
/// Resulting length:
/// `(L - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`.
///
/// * `x` — channels-first input `(batch, in_channels, L)`.
/// * `weight` — kernel in regular grouped-conv layout
///   `(out_channels, in_channels / groups, kernel_size)`. Axis 2 is flipped here.
///
/// Returns a channels-first `(batch, out_channels, L_out)` array.
///
/// # Errors
/// Propagates any error from `flip_axis`.
fn conv_transpose_1d_manual(
    x: &Array,
    weight: &Array,
    stride: i32,
    padding: i32,
    output_padding: i32,
    dilation: i32,
    groups: i32,
) -> Result<Array> {
    // Stride-1 conv on channels-first arrays. `ops::conv1d` is fed channels-last operands
    // (input `(N, L, C)`, kernel `(O, K, C / groups)`), so transpose in and back out.
    let conv1d_ncl = |input: &Array, w: &Array, pad: i32| -> Array {
        let input_nlc = input.transpose_axes(&[0, 2, 1]);
        let weight_okc = w.transpose_axes(&[0, 2, 1]);
        ops::conv1d(&input_nlc, &weight_okc, 1, pad, dilation, groups).transpose_axes(&[0, 2, 1])
    };

    let batch = x.dim(0);
    let in_channels = x.dim(1);
    let in_length = x.dim(2);
    let kernel_size = weight.dim(2);

    // Step 1: zero-insertion upsampling (a no-op for stride 1).
    // dtype code 10 is presumably float32 — TODO confirm it matches `x`.
    let upsampled = (stride > 1).then(|| {
        let zeros_between = Array::zeros(&[batch, in_channels, in_length, stride - 1], 10);
        let x_expanded = x.reshape(&[batch, in_channels, in_length, 1]);
        let interleaved = ops::concatenate_axis(&[&x_expanded, &zeros_between], -1)
            .reshape(&[batch, in_channels, in_length * stride]);
        // Drop the trailing zeros after the last sample.
        let upsampled_length = (in_length - 1) * stride + 1;
        interleaved.slice(&[0, 0, 0], &[batch, in_channels, upsampled_length])
    });
    let signal = upsampled.as_ref().unwrap_or(x);

    // Step 2: output_padding. This is equivalent to asymmetric padding: the right side gets
    // `output_padding` more than the left, so the extra samples see real kernel taps.
    let right_padded = (output_padding > 0).then(|| {
        let tail = Array::zeros(&[batch, in_channels, output_padding], 10);
        ops::concatenate_axis(&[signal, &tail], -1)
    });
    let signal = right_padded.as_ref().unwrap_or(signal);

    // Step 3: correlate with the flipped kernel.
    let weight_flipped = flip_axis(weight, 2)?;
    let conv_padding = dilation * (kernel_size - 1) - padding;
    let output = conv1d_ncl(signal, &weight_flipped, conv_padding.max(0));
    if conv_padding >= 0 {
        return Ok(output);
    }

    // `padding` exceeds the kernel's reach: remove the surplus from both ends instead of
    // padding by a negative amount.
    let crop = -conv_padding;
    let out_length = output.dim(2) - 2 * crop;
    Ok(output.slice(
        &[0, 0, crop],
        &[output.dim(0), output.dim(1), crop + out_length],
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Arrays are evaluated lazily; each test forces evaluation with `eval()` before checking
    // the shape. `eval` borrows the array, so no clone is needed.

    #[test]
    fn test_weight_norm_conv1d_shape() {
        let conv = WeightNormConv1d::new(4, 8, 3, Some(1), Some(1), None, None, Some(true))
            .expect("valid conv1d configuration");
        let y = conv.forward(&Array::random_normal(&[2, 4, 16], 10)).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[2, 8, 16]);
    }

    #[test]
    fn test_weight_norm_conv1d_no_bias() {
        let conv = WeightNormConv1d::new(4, 8, 3, Some(1), Some(1), None, None, Some(false))
            .expect("valid conv1d configuration");
        let y = conv.forward(&Array::random_normal(&[2, 4, 16], 10)).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[2, 8, 16]);
        assert!(conv.bias.is_none());
    }

    #[test]
    fn test_weight_norm_values() {
        let conv = WeightNormConv1d::new(2, 4, 3, None, None, None, None, None)
            .expect("valid conv1d configuration");
        let weight = conv.compute_weight().unwrap();
        weight.eval();
        assert_eq!(weight.shape(), &[4, 2, 3]);
    }

    #[test]
    fn test_conv_transpose1d_shape() {
        // (16 - 1) * 2 - 2 * 1 + (4 - 1) + 1 = 32
        let conv =
            WeightNormConvTranspose1d::new(8, 4, 4, Some(2), Some(1), None, None, None, Some(true))
                .expect("valid transposed-conv configuration");
        let y = conv.forward(&Array::random_normal(&[1, 8, 16], 10)).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[1, 4, 32]);
    }

    #[test]
    fn test_conv_transpose1d_upsample_4x() {
        // (8 - 1) * 4 - 2 * 6 + (16 - 1) + 1 = 32
        let conv = WeightNormConvTranspose1d::new(
            512,
            256,
            16,
            Some(4),
            Some(6),
            None,
            None,
            None,
            Some(true),
        )
        .expect("valid transposed-conv configuration");
        let y = conv.forward(&Array::random_normal(&[1, 512, 8], 10)).unwrap();
        y.eval();
        assert_eq!(y.shape(), &[1, 256, 32]);
    }
}