pub mod acoustic;
pub mod data_loader;
pub mod g2p;
pub mod progress;
pub mod vocoder;
use crate::GlobalOptions;
use clap::Subcommand;
use std::path::PathBuf;
use voirs_sdk::Result;
// Subcommands of `train`. The enum is parsed by clap's `Subcommand` derive and
// dispatched by `execute_train_command` in this module.
//
// NOTE: this enum deliberately uses plain `//` comments. With clap's derive,
// `///` doc comments become `--help` text, so adding them would change
// user-visible CLI output.
#[derive(Debug, Clone, Subcommand)]
pub enum TrainCommands {
// Train a vocoder. The flags from `lr_scheduler` through `save_frequency` are
// packed into a `TrainingConfig` by `execute_train_command`. Their
// `default_value` strings repeat the values in `TrainingConfig::default()`;
// keep the two in sync when changing either.
Vocoder {
#[arg(long, default_value = "diffwave")]
model_type: String,
// Required: no default and not an `Option`.
#[arg(long)]
data: PathBuf,
// `short` derives `-o` from the field name.
#[arg(short, long, default_value = "checkpoints/vocoder")]
output: PathBuf,
// `-c`; optional, forwarded unchanged to the trainer args.
#[arg(short, long)]
config: Option<PathBuf>,
#[arg(long, default_value = "1000")]
epochs: usize,
#[arg(long, default_value = "16")]
batch_size: usize,
#[arg(long, default_value = "0.0002")]
lr: f64,
// NOTE(review): stringly typed. The accepted values are interpreted by the
// vocoder module (not visible here) and are not validated at parse time.
// Consider a `value_parser` list or an enum.
#[arg(long, default_value = "none")]
lr_scheduler: String,
#[arg(long, default_value = "100")]
lr_step_size: usize,
#[arg(long, default_value = "0.1")]
lr_gamma: f64,
// Boolean flag: present = true, absent = false.
#[arg(long)]
early_stopping: bool,
#[arg(long, default_value = "50")]
patience: usize,
#[arg(long, default_value = "0.0001")]
min_delta: f64,
#[arg(long, default_value = "5")]
val_frequency: usize,
#[arg(long, default_value = "0")]
warmup_steps: usize,
#[arg(long, default_value = "1.0")]
grad_clip: f64,
#[arg(long, default_value = "10")]
save_frequency: usize,
// Optional path forwarded as `resume` to the trainer args. Presumably a
// checkpoint to continue from; TODO confirm in the vocoder module.
#[arg(long)]
resume: Option<PathBuf>,
// OR-ed with `global.gpu` in `execute_train_command` to form `use_gpu`.
#[arg(long)]
gpu: bool,
},
// Train an acoustic model. This takes a subset of the vocoder flags (no
// scheduler, early-stopping or checkpoint-frequency options) with its own
// defaults.
Acoustic {
#[arg(long, default_value = "vits")]
model_type: String,
// Required: no default and not an `Option`.
#[arg(long)]
data: PathBuf,
// `-o`.
#[arg(short, long, default_value = "checkpoints/acoustic")]
output: PathBuf,
// `-c`; optional.
#[arg(short, long)]
config: Option<PathBuf>,
#[arg(long, default_value = "500")]
epochs: usize,
#[arg(long, default_value = "32")]
batch_size: usize,
#[arg(long, default_value = "0.0001")]
lr: f64,
#[arg(long)]
resume: Option<PathBuf>,
// OR-ed with `global.gpu` in `execute_train_command` to form `use_gpu`.
#[arg(long)]
gpu: bool,
},
// Train a grapheme-to-phoneme model. The fields are passed positionally to
// `g2p::run_train_g2p`. There is no `--gpu`, `--batch-size` or `--resume` here.
G2p {
#[arg(long, default_value = "en")]
language: String,
// Required: no default and not an `Option`.
#[arg(long)]
dictionary: PathBuf,
// `-o`; unlike the other variants, the default is a single file path.
#[arg(short, long, default_value = "models/g2p.safetensors")]
output: PathBuf,
// `-c`; optional.
#[arg(short, long)]
config: Option<PathBuf>,
#[arg(long, default_value = "100")]
epochs: usize,
#[arg(long, default_value = "0.001")]
lr: f64,
},
}
/// Secondary vocoder-training settings: learning-rate scheduling, early
/// stopping, validation cadence, warmup, gradient clipping and checkpointing.
///
/// `execute_train_command` fills this from the `train vocoder` flags. The
/// [`Default`] values are the same ones the CLI uses when a flag is omitted.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// From `--lr-scheduler` (default `"none"`).
    pub lr_scheduler: String,
    /// From `--lr-step-size` (default `100`).
    pub lr_step_size: usize,
    /// From `--lr-gamma` (default `0.1`).
    pub lr_gamma: f64,
    /// From the `--early-stopping` flag (default `false`).
    pub early_stopping: bool,
    /// From `--patience` (default `50`).
    pub patience: usize,
    /// From `--min-delta` (default `0.0001`).
    pub min_delta: f64,
    /// From `--val-frequency` (default `5`).
    pub val_frequency: usize,
    /// From `--warmup-steps` (default `0`).
    pub warmup_steps: usize,
    /// From `--grad-clip` (default `1.0`).
    pub grad_clip: f64,
    /// From `--save-frequency` (default `10`).
    pub save_frequency: usize,
}

impl Default for TrainingConfig {
    /// Mirrors the `default_value`s of the `train vocoder` subcommand.
    fn default() -> Self {
        let lr_scheduler = String::from("none");
        TrainingConfig {
            // Learning-rate schedule.
            lr_scheduler,
            lr_step_size: 100,
            lr_gamma: 0.1,
            warmup_steps: 0,
            // Early stopping / validation.
            early_stopping: false,
            patience: 50,
            min_delta: 0.0001,
            val_frequency: 5,
            // Optimiser safety and checkpoint cadence.
            grad_clip: 1.0,
            save_frequency: 10,
        }
    }
}
/// Runs the training job selected by a parsed `train` subcommand.
///
/// Each [`TrainCommands`] variant is unpacked and forwarded to its trainer:
/// `vocoder::run_train_vocoder`, `acoustic::run_train_acoustic` or
/// `g2p::run_train_g2p`. The vocoder and acoustic trainers are asked to use
/// the GPU when either the subcommand's `--gpu` flag or `global.gpu` is set.
///
/// # Errors
///
/// Returns whatever error the selected trainer returns. This function adds
/// no failure modes of its own.
pub async fn execute_train_command(command: TrainCommands, global: &GlobalOptions) -> Result<()> {
    match command {
        TrainCommands::Vocoder {
            model_type,
            data,
            output,
            config,
            epochs,
            batch_size,
            lr,
            lr_scheduler,
            lr_step_size,
            lr_gamma,
            early_stopping,
            patience,
            min_delta,
            val_frequency,
            warmup_steps,
            grad_clip,
            save_frequency,
            resume,
            gpu,
        } => {
            // The scheduling, early-stopping and checkpointing flags travel
            // together as one value rather than as loose trainer arguments.
            let training_config = TrainingConfig {
                lr_scheduler,
                lr_step_size,
                lr_gamma,
                early_stopping,
                patience,
                min_delta,
                val_frequency,
                warmup_steps,
                grad_clip,
                save_frequency,
            };
            let args = vocoder::VocoderTrainingArgs {
                model_type,
                data,
                output,
                config,
                epochs,
                batch_size,
                lr,
                resume,
                use_gpu: gpu || global.gpu,
                training_config,
            };
            vocoder::run_train_vocoder(args, global).await
        }
        TrainCommands::Acoustic {
            model_type,
            data,
            output,
            config,
            epochs,
            batch_size,
            lr,
            resume,
            gpu,
        } => {
            // These bindings are owned values moved out of `command` and not
            // used afterwards, so they move straight into the args struct.
            // The previous `.clone()` calls only added allocations.
            let args = acoustic::AcousticModelTrainingArgs {
                model_type,
                data,
                output,
                config,
                epochs,
                batch_size,
                lr,
                resume,
                use_gpu: gpu || global.gpu,
            };
            acoustic::run_train_acoustic(args, global).await
        }
        TrainCommands::G2p {
            language,
            dictionary,
            output,
            config,
            epochs,
            lr,
        } => g2p::run_train_g2p(language, dictionary, output, config, epochs, lr, global).await,
    }
}