maolan-generate 0.0.1

Generate music using Maolan and HeartMuLa
Documentation
use anyhow::{Context, Result, anyhow, bail};
use burn::prelude::Backend;
use maolan_generate::BackendChoice;
use maolan_generate::heartmula_runtime;
use serde::Deserialize;
use std::env;
use std::ffi::OsString;
use std::path::{Path, PathBuf};

/// Default model directory, expressed relative to the user's home directory.
///
/// `resolve_model_dir` joins this onto the home directory when no explicit
/// model directory override is supplied.
const DEFAULT_HEARTMULA_MODEL_REPO_DIR: &str =
    "repos/heartmula-burn/artifacts/heartmula-happy-new-year-20260123";

/// Fully parsed command-line options, as produced by `parse_options`.
#[derive(Debug, Clone)]
struct Options {
    /// Compute backend selected with `--backend` (`cpu` or `vulkan`).
    backend: BackendChoice,
    /// Model directory given with `--model-dir`; `None` means "use the default
    /// location chosen by `resolve_model_dir`".
    model_dir: Option<PathBuf>,
    /// Path the debug JSON is written to (`--output-json`).
    output_json: PathBuf,
    /// Lyrics from `--lyrics`, trimmed and lower-cased; guaranteed non-empty
    /// by `parse_options`.
    lyrics: String,
    /// Tags after being passed through `heartmula_runtime::normalize_tags`.
    tags: String,
    /// Value of `--length`; `run_with_backend` divides it by 80 to obtain the
    /// maximum number of audio frames.
    /// NOTE(review): the unit is not visible in this file — confirm.
    length: i64,
    /// Value of `--topk`, forwarded unchanged to the generation config.
    topk: usize,
    /// Value of `--temperature`, forwarded unchanged to the generation config.
    temperature: f32,
    /// Value of `--cfg-scale`, forwarded unchanged to the generation config.
    cfg_scale: f32,
}

/// Fields deserialized from the model directory's `gen_config.json`.
///
/// Each value is copied verbatim into
/// `heartmula_runtime::HeartmulaGenerationConfig` by `run_with_backend`.
/// NOTE(review): judging by the names these are special token ids
/// (begin/end-of-sequence markers and a padding id) — confirm against the
/// runtime's documentation.
#[derive(Debug, Deserialize)]
struct HeartmulaGenConfig {
    text_bos_id: i64,
    text_eos_id: i64,
    audio_eos_id: i64,
    empty_id: i64,
}

/// Returns the usage text shown for `-h` / `--help`.
///
/// `main` recognises a help request by comparing an error's message against
/// this exact text, so the content must stay byte-stable.
fn help_text() -> &'static str {
    // One fragment per output line; `concat!` joins them at compile time.
    concat!(
        "heartmula_debug_first_frame\n",
        "\n",
        "Usage:\n",
        "  cargo run --release -p maolan-generate --bin heartmula_debug_first_frame -- [options]\n",
        "\n",
        "Options:\n",
        "  --backend <cpu|vulkan>\n",
        "  --model-dir <path>\n",
        "  --output-json <path>\n",
        "  --lyrics <text>\n",
        "  --tags <text>\n",
        "  --length <int>\n",
        "  --topk <int>\n",
        "  --temperature <float>\n",
        "  --cfg-scale <float>\n",
        "  -h, --help\n",
    )
}

fn parse_options(args: impl IntoIterator<Item = OsString>) -> Result<Options> {
    let mut args = args.into_iter();
    let _program = args.next();
    let mut backend = BackendChoice::Cpu;
    let mut model_dir = None;
    let mut output_json = PathBuf::from("heartmula_first_frame.json");
    let mut lyrics = None;
    let mut tags = Some(heartmula_runtime::default_tags().to_string());
    let mut length = 2000_i64;
    let mut topk = 50_usize;
    let mut temperature = 1.0_f32;
    let mut cfg_scale = 6.0_f32;

    while let Some(arg) = args.next() {
        let arg = arg
            .into_string()
            .map_err(|_| anyhow!("arguments must be valid UTF-8"))?;
        if matches!(arg.as_str(), "-h" | "--help") {
            bail!(help_text());
        }
        match arg.as_str() {
            "--backend" => {
                let value = args
                    .next()
                    .ok_or_else(|| anyhow!("missing value after --backend"))?
                    .into_string()
                    .map_err(|_| anyhow!("backend value must be valid UTF-8"))?;
                backend = match value.as_str() {
                    "cpu" => BackendChoice::Cpu,
                    "vulkan" => BackendChoice::Vulkan,
                    _ => bail!("unsupported backend '{value}', expected cpu or vulkan"),
                };
            }
            "--model-dir" => {
                model_dir = Some(PathBuf::from(
                    args.next()
                        .ok_or_else(|| anyhow!("missing value after --model-dir"))?,
                ));
            }
            "--output-json" => {
                output_json = PathBuf::from(
                    args.next()
                        .ok_or_else(|| anyhow!("missing value after --output-json"))?,
                );
            }
            "--lyrics" => {
                lyrics = Some(
                    args.next()
                        .ok_or_else(|| anyhow!("missing value after --lyrics"))?
                        .into_string()
                        .map_err(|_| anyhow!("lyrics value must be valid UTF-8"))?,
                );
            }
            "--tags" => {
                tags = Some(
                    args.next()
                        .ok_or_else(|| anyhow!("missing value after --tags"))?
                        .into_string()
                        .map_err(|_| anyhow!("tags value must be valid UTF-8"))?,
                );
            }
            "--length" => {
                let value = args
                    .next()
                    .ok_or_else(|| anyhow!("missing value after --length"))?
                    .into_string()
                    .map_err(|_| anyhow!("length value must be valid UTF-8"))?;
                length = value
                    .parse::<i64>()
                    .map_err(|_| anyhow!("length must be a whole number"))?;
            }
            "--topk" => {
                let value = args
                    .next()
                    .ok_or_else(|| anyhow!("missing value after --topk"))?
                    .into_string()
                    .map_err(|_| anyhow!("topk value must be valid UTF-8"))?;
                topk = value
                    .parse::<usize>()
                    .map_err(|_| anyhow!("topk must be a whole number"))?;
            }
            "--temperature" => {
                let value = args
                    .next()
                    .ok_or_else(|| anyhow!("missing value after --temperature"))?
                    .into_string()
                    .map_err(|_| anyhow!("temperature value must be valid UTF-8"))?;
                temperature = value
                    .parse::<f32>()
                    .map_err(|_| anyhow!("temperature must be a number"))?;
            }
            "--cfg-scale" => {
                let value = args
                    .next()
                    .ok_or_else(|| anyhow!("missing value after --cfg-scale"))?
                    .into_string()
                    .map_err(|_| anyhow!("cfg-scale value must be valid UTF-8"))?;
                cfg_scale = value
                    .parse::<f32>()
                    .map_err(|_| anyhow!("cfg-scale must be a number"))?;
            }
            _ => bail!("unexpected argument '{arg}'"),
        }
    }

    let lyrics = lyrics
        .map(|s| s.trim().to_lowercase())
        .filter(|s| !s.is_empty())
        .ok_or_else(|| anyhow!("--lyrics is required"))?;
    let tags = heartmula_runtime::normalize_tags(tags.as_deref().unwrap_or_default());
    Ok(Options {
        backend,
        model_dir,
        output_json,
        lyrics,
        tags,
        length,
        topk,
        temperature,
        cfg_scale,
    })
}

/// Determines the directory the model files are loaded from.
///
/// Returns `override_dir` unchanged when it is given. Otherwise the default
/// location is `DEFAULT_HEARTMULA_MODEL_REPO_DIR` under the user's home
/// directory, taken from `HOME`, then `USERPROFILE`, then `/tmp` as a last
/// resort.
///
/// The environment is read with `env::var_os`, so a home directory whose path
/// is not valid UTF-8 is still honoured instead of silently falling back.
///
/// # Errors
/// Currently never fails; the `Result` return type is kept for callers.
fn resolve_model_dir(override_dir: Option<&Path>) -> Result<PathBuf> {
    if let Some(path) = override_dir {
        return Ok(path.to_path_buf());
    }
    let home = env::var_os("HOME")
        .or_else(|| env::var_os("USERPROFILE"))
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("/tmp"));
    Ok(home.join(DEFAULT_HEARTMULA_MODEL_REPO_DIR))
}

/// File name of the HeartMuLa burnpack, relative to the model directory.
fn heartmula_raw_bpk_rel() -> &'static str {
    const RAW_BPK_FILE: &str = "heartmula.bpk";
    RAW_BPK_FILE
}

/// Reads the file at `path` and deserializes its JSON content into a
/// [`HeartmulaGenConfig`].
///
/// # Errors
/// Reports `failed to read <path>` when the file cannot be read and
/// `failed to parse <path>` when its content is not the expected JSON.
fn load_gen_config(path: &Path) -> Result<HeartmulaGenConfig> {
    // Both failure modes share one message shape; only the verb differs.
    let failure = |action: &str| format!("failed to {action} {}", path.display());

    let raw = std::fs::read(path).with_context(|| failure("read"))?;
    let parsed =
        serde_json::from_slice::<HeartmulaGenConfig>(&raw).with_context(|| failure("parse"))?;
    Ok(parsed)
}

fn run_with_backend<B: Backend>(options: &Options) -> Result<()>
where
    B::Device: Default,
{
    let model_dir = resolve_model_dir(options.model_dir.as_deref())?;
    let heartmula_raw_bpk = model_dir.join(heartmula_raw_bpk_rel());
    let tokenizer_json = model_dir.join("tokenizer.json");
    let gen_config_json = model_dir.join("gen_config.json");
    let device = Default::default();
    let config = load_gen_config(&gen_config_json)?;
    let lyrics_ids = heartmula_runtime::tokenize_text(&tokenizer_json, &options.lyrics)?;
    let tags_ids = heartmula_runtime::tokenize_text(&tokenizer_json, &options.tags)?;
    let model = heartmula_runtime::HeartmulaModel::<B>::from_burnpack(
        &heartmula_raw_bpk,
        &device,
        128_256,
        8_197,
    )?;
    let mut generation_config = heartmula_runtime::HeartmulaGenerationConfig {
        text_bos_id: config.text_bos_id,
        text_eos_id: config.text_eos_id,
        audio_eos_id: config.audio_eos_id,
        empty_id: config.empty_id,
        lyrics_ids: &lyrics_ids,
        tags_ids: &tags_ids,
        max_audio_frames: ((options.length.max(1) as usize) / 80).max(1),
        temperature: options.temperature,
        topk: options.topk,
        cfg_scale: options.cfg_scale,
        progress_callback: None,
    };
    let debug = model.debug_first_frame(&device, &mut generation_config)?;
    std::fs::write(&options.output_json, serde_json::to_vec_pretty(&debug)?)
        .with_context(|| format!("failed to write {}", options.output_json.display()))?;
    println!("debug_json={}", options.output_json.display());
    Ok(())
}

/// Entry point: parses the command line and dispatches to the chosen backend.
///
/// A help request surfaces from `parse_options` as an error carrying the usage
/// text; it is printed to stdout and the process exits successfully. Any other
/// parse error is returned to the caller.
fn main() -> Result<()> {
    // Concrete backend types behind each `BackendChoice`.
    type CpuBackend = burn::backend::NdArray<f32>;
    type VulkanBackend = burn::backend::Wgpu<f32, i64, u32>;

    let options = match parse_options(env::args_os()) {
        Ok(parsed) => parsed,
        Err(err) => {
            let usage = help_text();
            if err.to_string() != usage {
                return Err(err);
            }
            println!("{}", usage);
            return Ok(());
        }
    };

    match options.backend {
        BackendChoice::Vulkan => {
            // The wgpu runtime is initialised for Vulkan before any model code runs.
            let device = burn::backend::wgpu::WgpuDevice::default();
            burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
                &device,
                Default::default(),
            );
            run_with_backend::<VulkanBackend>(&options)
        }
        BackendChoice::Cpu => run_with_backend::<CpuBackend>(&options),
    }
}