use anyhow::anyhow;
use clap::{Parser, ValueEnum};
use directories::ProjectDirs;
use std::fmt::{Display, Formatter};
use std::path::PathBuf;
use tracing::warn;
use crate::backend::*;
use crate::onnxruntime_lib;
use crate::storage::*;
use crate::terminal::*;
use crate::{gpu, musicgen_models};
/// Number of `input_ids` generation steps per second of generated audio.
///
/// NOTE(review): presumably matches the 50 frames/s rate of MusicGen's audio
/// codec; confirm against the exported models before changing it.
pub const INPUT_IDS_BATCH_PER_SECOND: usize = 50;

/// MusicGen model variants selectable with `--model`.
///
/// The human-readable name of each variant comes from its `Display` impl.
/// Per-variant `///` docs are intentionally omitted: clap's `ValueEnum` derive
/// would surface them as `--help` text.
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
pub enum Model {
    Small,
    SmallFp16,
    SmallQuant,
    Medium,
    MediumFp16,
    MediumQuant,
    Large,
}
impl Display for Model {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Model::Small => write!(f, "MusicGen Small"),
Model::SmallFp16 => write!(f, "MusicGen Small Fp16"),
Model::SmallQuant => write!(f, "MusicGen Small Quantized"),
Model::Medium => write!(f, "MusicGen Medium"),
Model::MediumFp16 => write!(f, "MusicGen Medium Fp16"),
Model::MediumQuant => write!(f, "MusicGen Medium Quantized"),
Model::Large => write!(f, "MusicGen Large"),
}
}
}
// Command line arguments of the MusicGPT binary.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into `--help` text, so these notes stay out of the CLI output.
#[derive(Parser)]
#[command(name = "MusicGPT")]
#[command(version, about, long_about = None)]
struct Args {
    // Positional text prompt. When left empty, `cli()` starts the web UI
    // instead of the terminal loop.
    #[arg(default_value = "")]
    prompt: String,
    // MusicGen variant to load (parsed through `Model`'s `ValueEnum`).
    #[arg(long, default_value = "small")]
    model: Model,
    // Forwarded to `MusicGenModels::new`; presumably selects a split-decoder
    // export of the model — TODO confirm.
    #[arg(long, default_value = "false")]
    use_split_decoder: bool,
    // Forwarded to `MusicGenModels::new`; presumably forces re-downloading
    // model files even when already stored — TODO confirm.
    #[arg(long, default_value = "false")]
    force_download: bool,
    // Root directory handed to `AppFs`. When absent, `cli()` falls back to the
    // platform data directory for ("com", "gabotechs", "musicgpt").
    #[arg(long, default_value = None)]
    data_path: Option<PathBuf>,
    // Registers a GPU execution provider with ONNX Runtime (experimental).
    #[arg(long, default_value = "false")]
    gpu: bool,
    // Initial generation length for the terminal loop (seconds, by name);
    // `validate` enforces the range 1..=30.
    #[arg(long, default_value = "10")]
    secs: usize,
    // Initial output path handed to the terminal loop.
    #[arg(long, default_value = "musicgpt-generated.wav")]
    output: String,
    // Terminal mode: forwarded as `RunTerminalOptions::no_playback`.
    #[arg(long, default_value = "false")]
    no_playback: bool,
    // Terminal mode: forwarded as `RunTerminalOptions::no_interactive`;
    // `validate` requires a non-empty prompt when set.
    #[arg(long, default_value = "false")]
    no_interactive: bool,
    // Web UI: meant to suppress automatically opening the UI in a browser.
    // NOTE(review): name-derived; verify `cli()` actually honors this flag.
    #[arg(long, default_value = "false")]
    ui_no_open: bool,
    // Web UI: forwarded as `RunWebServerOptions::port`.
    #[arg(long, default_value = "8642")]
    ui_port: usize,
    // Web UI: forwarded as `RunWebServerOptions::expose`; presumably binds to a
    // non-loopback address — TODO confirm.
    #[arg(long, default_value = "false")]
    ui_expose: bool,
}
impl Args {
    /// Longest clip, in seconds, that may be requested with `--secs`.
    const MAX_SECS: usize = 30;

    /// Checks argument constraints that clap does not enforce by itself.
    ///
    /// Runs before any model is downloaded or loaded, so bad input fails fast.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - `--secs` is 0 or greater than [`Self::MAX_SECS`];
    /// - `--no-interactive` is set but no prompt was given;
    /// - `--ui-port` does not fit in a TCP port number (0..=65535).
    fn validate(&self) -> anyhow::Result<()> {
        if self.secs == 0 {
            return Err(anyhow!("--secs must be > 0"));
        }
        if self.secs > Self::MAX_SECS {
            return Err(anyhow!("--secs must be <= {}", Self::MAX_SECS));
        }
        if self.no_interactive && self.prompt.is_empty() {
            return Err(anyhow!(
                "A prompt must be provided when not in interactive mode"
            ));
        }
        // `ui_port` is parsed as `usize`, so clap accepts values that can never
        // be bound; reject them here instead of after the model download.
        if u16::try_from(self.ui_port).is_err() {
            return Err(anyhow!("--ui-port must be <= {}", u16::MAX));
        }
        Ok(())
    }
}
pub async fn cli() -> anyhow::Result<()> {
let args = Args::parse();
args.validate()?;
let storage = AppFs::new(
args.data_path.unwrap_or(
ProjectDirs::from("com", "gabotechs", "musicgpt")
.expect("Could not load project directory")
.data_dir()
.into(),
),
);
let root = storage.root.clone();
let mut ort_builder = onnxruntime_lib::init::init(storage.clone()).await?;
let device = if args.gpu {
warn!("GPU support is experimental, it might not work on most platforms");
let (gpu_device, provider) = gpu::init_gpu()?;
ort_builder = ort_builder.with_execution_providers(&[provider]);
gpu_device
} else {
"Cpu"
};
ort_builder.commit()?;
let musicgen_models = musicgen_models::MusicGenModels::new(
storage.clone(),
args.model,
args.use_split_decoder,
args.force_download,
)
.await?;
if args.prompt.is_empty() {
run_web_server(
root,
storage,
musicgen_models,
RunWebServerOptions {
name: args.model.to_string(),
device: device.to_string(),
port: args.ui_port,
auto_open: true,
expose: args.ui_expose,
},
)
.await
} else {
run_terminal_loop(
root,
musicgen_models,
RunTerminalOptions {
init_prompt: args.prompt,
init_secs: args.secs,
init_output: args.output,
no_playback: args.no_playback,
no_interactive: args.no_interactive,
},
)
.await
}
}