// maolan-generate 0.0.8
//
// Generate music using Maolan and HeartMuLa
// Documentation
//! Weight loading utilities for HeartCodec
//!
//! Handles PyTorch weight normalization by computing actual weights at load time:
//! weight = original0 * (original1 / norm(original1))

use burn::nn::PaddingConfig1d;
use burn::nn::conv::{Conv1d, Conv1dConfig};
use burn::prelude::Backend;
use burn::tensor::Tensor;

/// Build a standard [`Conv1d`] from the parameters of a weight-normalized
/// convolution.
///
/// The burnpack stores:
/// - `parametrizations.weight.original0`: [out_ch, 1, 1] - magnitude
/// - `parametrizations.weight.original1`: [out_ch, in_ch/groups, kernel] - direction
///
/// When both `original0` and `original1` are supplied, the effective kernel
/// `original0 * original1 / (norm(original1) + 1e-12)` is computed, with the
/// norm taken over the last two axes separately for every output channel, and
/// installed as the layer's weight. When `bias_data` is supplied it is
/// installed as the layer's bias (expected length: `out_channels`).
///
/// Anything that is not supplied keeps the value produced by
/// [`Conv1dConfig::init`]. In particular, if only one of `original0` /
/// `original1` is present, the weight is left at its default initialization.
///
/// # Panics
///
/// Panics if `groups` is zero (integer division by zero). Mismatches between
/// the buffer lengths and the shapes derived from the arguments are expected
/// to be rejected by `TensorData::new` / `Tensor::from_data` — NOTE(review):
/// the exact failure mode depends on the Burn version in use.
#[allow(clippy::too_many_arguments)] // mirrors the Conv1d hyper-parameters plus the raw checkpoint buffers
pub fn load_weight_norm_conv1d<B: Backend>(
    device: &B::Device,
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    stride: usize,
    padding: PaddingConfig1d,
    dilation: usize,
    groups: usize,
    bias_data: Option<Vec<f32>>,
    original0: Option<Vec<f32>>,
    original1: Option<Vec<f32>>,
) -> Conv1d<B> {
    let in_channels_per_group = in_channels / groups;

    let mut conv = Conv1dConfig::new(in_channels, out_channels, kernel_size)
        .with_stride(stride)
        .with_padding(padding)
        .with_dilation(dilation)
        .with_groups(groups)
        .init(device);

    if let (Some(g_data), Some(v_data)) = (original0, original1) {
        let g = Tensor::<B, 3>::from_data(
            burn::tensor::TensorData::new(g_data, [out_channels, 1, 1]),
            device,
        );

        let v = Tensor::<B, 3>::from_data(
            burn::tensor::TensorData::new(
                v_data,
                [out_channels, in_channels_per_group, kernel_size],
            ),
            device,
        );

        // `sum_dim` keeps each reduced axis as a size-1 dimension, so the norm
        // already has shape [out_channels, 1, 1] and broadcasts against `v`
        // without any further reshaping.
        let v_norm = v.clone().powf_scalar(2.0).sum_dim(2).sum_dim(1).sqrt();

        // The small epsilon guards against division by zero for all-zero rows.
        let weight = g * v / (v_norm + 1e-12);

        // Replace the default-initialized kernel with the reconstructed one.
        conv.weight = burn::module::Param::from_tensor(weight);
    }

    if let Some(bias_values) = bias_data {
        let bias = Tensor::<B, 1>::from_data(
            burn::tensor::TensorData::new(bias_values, [out_channels]),
            device,
        );
        conv.bias = Some(burn::module::Param::from_tensor(bias));
    }

    conv
}

/// Compute the effective weight from a weight-normalization decomposition.
///
/// Returns `g * v / (norm(v) + 1e-12)`, where the norm is the square root of
/// the sum of squares of `v` over axes 1 and 2, i.e. one norm per index of
/// axis 0. The epsilon guards against division by zero for all-zero slices.
///
/// `g` must be broadcastable against `v`; the loader in this module uses
/// `[out_ch, 1, 1]` for `g` and `[out_ch, in_ch/groups, kernel]` for `v`.
/// Neither input is consumed; both are cloned internally.
pub fn compute_weight_norm_weight<B: Backend>(g: &Tensor<B, 3>, v: &Tensor<B, 3>) -> Tensor<B, 3> {
    // `sum_dim` keeps each reduced axis as a size-1 dimension, so the result
    // is already rank 3 and broadcasts against `v` as-is.
    let v_norm = v.clone().powf_scalar(2.0).sum_dim(2).sum_dim(1).sqrt();

    g.clone() * v.clone() / (v_norm + 1e-12)
}