maolan-generate 0.0.1

Generate music using Maolan and HeartMuLa
Documentation
use anyhow::{Context, Result, bail};
use burn::tensor::DType;
use burn_store::{BurnpackStore, BurnpackWriter, ModuleStore, TensorSnapshot};
use std::env;
use std::path::PathBuf;

fn main() -> Result<()> {
    let mut args = env::args_os();
    let _program = args.next();
    let input = PathBuf::from(args.next().context("missing input .bpk path")?);
    let output = PathBuf::from(args.next().context("missing output .bpk path")?);

    if args.next().is_some() {
        bail!("usage: convert_burnpack <input.bpk> <output.bpk>");
    }

    let mut store = BurnpackStore::from_file(&input);
    let snapshots = store
        .get_all_snapshots()
        .context("failed to read burnpack snapshots")?;

    let mut converted = Vec::with_capacity(snapshots.len());
    let mut converted_count = 0usize;

    for snapshot in snapshots.values() {
        let data = snapshot
            .to_data()
            .with_context(|| format!("failed to materialize snapshot {}", snapshot.full_path()))?;

        let converted_snapshot = if data.dtype == DType::F16 {
            let values = data
                .to_vec::<half::f16>()
                .with_context(|| format!("failed to decode F16 tensor {}", snapshot.full_path()))?
                .into_iter()
                .map(f32::from)
                .collect::<Vec<_>>();
            converted_count += 1;
            TensorSnapshot::from_data(
                burn::tensor::TensorData::new(values, data.shape.clone()),
                snapshot.path_stack.clone().unwrap_or_default(),
                snapshot.container_stack.clone().unwrap_or_default(),
                snapshot
                    .tensor_id
                    .context("tensor snapshot is missing param id")?,
            )
        } else {
            snapshot.clone()
        };

        converted.push(converted_snapshot);
    }

    BurnpackWriter::new(converted)
        .write_to_file(&output)
        .with_context(|| format!("failed to write {}", output.display()))?;

    eprintln!(
        "converted {} F16 tensors from {} to {}",
        converted_count,
        input.display(),
        output.display()
    );

    Ok(())
}