use crate::GlobalOptions;
use candle_core::{Device, Tensor};
use std::path::{Path, PathBuf};
use voirs_sdk::Result;
use voirs_vocoder::models::diffwave::{DiffWave, SamplingMethod};
/// Borrowed CLI arguments for a vocoder inference run.
///
/// Either single-file mode (`mel_path`/`output`) or batch mode
/// (`batch_input`/`batch_output`) is selected by `run_vocoder_inference`.
#[derive(Debug)]
pub struct VocoderInferenceConfig<'a> {
    /// Path to the DiffWave checkpoint (safetensors format).
    pub checkpoint: &'a Path,
    /// Input mel spectrogram file; a random dummy mel is generated when `None`.
    pub mel_path: Option<&'a Path>,
    /// Output WAV path for single-file mode.
    pub output: &'a Path,
    /// Number of diffusion steps; overridden when `quality` is set.
    pub steps: usize,
    /// Quality preset name ("fast", "balanced", or "high"); takes precedence over `steps`.
    pub quality: Option<&'a str>,
    /// Directory of mel files for batch mode (must be paired with `batch_output`).
    pub batch_input: Option<&'a PathBuf>,
    /// Output directory for batch WAVs (must be paired with `batch_input`).
    pub batch_output: Option<&'a PathBuf>,
    /// Print timing / real-time-factor metrics after inference.
    pub show_metrics: bool,
}
/// Named quality presets, each mapping to a fixed diffusion step count
/// (see `QualityPreset::steps`).
#[derive(Debug, Clone, Copy)]
enum QualityPreset {
    Fast,
    Balanced,
    High,
}
impl QualityPreset {
fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"fast" => Ok(Self::Fast),
"balanced" => Ok(Self::Balanced),
"high" => Ok(Self::High),
_ => Err(voirs_sdk::VoirsError::config_error(format!(
"Invalid quality preset: {}. Use 'fast', 'balanced', or 'high'",
s
))),
}
}
fn steps(&self) -> usize {
match self {
Self::Fast => 20,
Self::Balanced => 50,
Self::High => 100,
}
}
}
/// Entry point for the vocoder inference subcommand.
///
/// Dispatches to batch mode when both `batch_input` and `batch_output` are
/// set, errors when exactly one of them is set, and otherwise runs
/// single-file inference.
pub async fn run_vocoder_inference(
    config: VocoderInferenceConfig<'_>,
    global: &GlobalOptions,
) -> Result<()> {
    // Matching on the pair makes the "only one batch flag given" error case
    // structurally explicit and removes the `.expect()` calls the previous
    // manual `is_some`/`is_none` checks required.
    match (config.batch_input, config.batch_output) {
        (Some(batch_input), Some(batch_output)) => {
            run_batch_inference(
                config.checkpoint,
                batch_input,
                batch_output,
                config.steps,
                config.quality,
                config.show_metrics,
                global,
            )
            .await
        }
        (None, None) => {
            run_single_inference(
                config.checkpoint,
                config.mel_path,
                config.output,
                config.steps,
                config.quality,
                config.show_metrics,
                global,
            )
            .await
        }
        _ => Err(voirs_sdk::VoirsError::config_error(
            "Batch mode requires both --batch-input and --batch-output",
        )),
    }
}
/// Run DiffWave vocoder inference on one mel spectrogram and write the result
/// as a 22.05 kHz mono 16-bit WAV file.
///
/// A `quality` preset, when given, overrides the explicit `steps` value.
/// When `mel_path` is `None`, a random dummy mel is generated so the pipeline
/// can be exercised without real input.
async fn run_single_inference(
    checkpoint: &Path,
    mel_path: Option<&Path>,
    output: &Path,
    mut steps: usize,
    quality: Option<&str>,
    show_metrics: bool,
    global: &GlobalOptions,
) -> Result<()> {
    // Quality preset takes precedence over the explicit --steps value.
    if let Some(quality_str) = quality {
        let preset = QualityPreset::from_str(quality_str)?;
        steps = preset.steps();
        if !global.quiet {
            println!("Using quality preset: {:?} ({} steps)", preset, steps);
        }
    }
    use std::time::Instant;
    let total_start = Instant::now();
    if !global.quiet {
        println!("🎵 VoiRS Vocoder Inference");
        println!("═══════════════════════════════════════");
        println!("Checkpoint: {}", checkpoint.display());
        if let Some(mel) = mel_path {
            println!("Mel spec: {}", mel.display());
        } else {
            println!("Mel spec: <generating dummy>");
        }
        println!("Output: {}", output.display());
        println!("Steps: {}", steps);
        println!("═══════════════════════════════════════\n");
    }
    // Device selection: CUDA when requested and compiled in, falling back to
    // CPU when CUDA init fails or the binary was built without the feature.
    let device = if global.gpu {
        #[cfg(feature = "cuda")]
        {
            Device::new_cuda(0).unwrap_or(Device::Cpu)
        }
        #[cfg(not(feature = "cuda"))]
        {
            if !global.quiet {
                println!("⚠️ GPU requested but CUDA not available, using CPU");
            }
            Device::Cpu
        }
    } else {
        Device::Cpu
    };
    if !global.quiet {
        println!("📦 Loading DiffWave model from checkpoint...");
    }
    let model = DiffWave::load_from_safetensors(checkpoint, device.clone()).map_err(|e| {
        voirs_sdk::VoirsError::config_error(format!("Failed to load DiffWave model: {}", e))
    })?;
    if !global.quiet {
        println!("✓ Model loaded successfully");
        println!(" Parameters: {}", model.num_parameters());
        println!();
    }
    // Obtain the conditioning mel: from file when provided, random otherwise.
    let mel_tensor = if let Some(mel_file) = mel_path {
        if !global.quiet {
            println!("📊 Loading mel spectrogram from file...");
        }
        load_mel_spectrogram(mel_file, &device)?
    } else {
        if !global.quiet {
            println!("📊 Generating dummy mel spectrogram...");
        }
        generate_dummy_mel_spectrogram(&device)?
    };
    if !global.quiet {
        println!("✓ Mel spectrogram ready");
        println!(" Shape: {:?}", mel_tensor.dims());
        println!();
    }
    if !global.quiet {
        println!("🔄 Running vocoder inference...");
        println!(" Sampling method: DDIM");
        println!(" Diffusion steps: {}", steps);
    }
    // eta = 0.0 selects deterministic DDIM sampling.
    let sampling_method = SamplingMethod::DDIM { steps, eta: 0.0 };
    let audio_tensor = model
        .inference(&mel_tensor, sampling_method)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Inference failed: {}", e)))?;
    if !global.quiet {
        println!("✓ Inference complete");
        println!(" Audio shape: {:?}", audio_tensor.dims());
        println!();
    }
    if !global.quiet {
        println!("💾 Saving audio to {}...", output.display());
    }
    save_audio_tensor(&audio_tensor, output, 22050)?;
    let total_time = total_start.elapsed();
    if !global.quiet {
        println!("✅ Vocoder inference complete!");
        println!(" Output: {}", output.display());
    }
    // Metrics print regardless of --quiet since --show-metrics is explicit.
    if show_metrics {
        println!();
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
        println!("Performance Metrics:");
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
        println!("Total time: {:.3}s", total_time.as_secs_f64());
        // NOTE(review): RTF is computed over the whole run (model load, mel
        // load, inference, save), not inference alone — confirm that is the
        // intended metric. Duration assumes a fixed 22050 Hz sample rate.
        if let Ok(dims) = audio_tensor.dims3() {
            let (_, _, samples) = dims;
            let duration_sec = samples as f64 / 22050.0;
            let rtf = total_time.as_secs_f64() / duration_sec;
            println!("Audio duration: {:.2}s", duration_sec);
            println!("Real-time factor: {:.3}x", rtf);
        }
        println!("╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌");
    }
    Ok(())
}
/// Load a mel spectrogram tensor from disk, dispatching on the file extension.
///
/// Supported: `.npy` and `.safetensors`; `.pt`/`.pth` produce an error that
/// lists conversion options. Anything else is an unsupported-format error.
fn load_mel_spectrogram(path: &Path, device: &Device) -> Result<Tensor> {
    let extension = path.extension().and_then(|e| e.to_str());
    match extension {
        Some("npy") => load_numpy_file(path, device),
        Some("pt") | Some("pth") => load_pytorch_file(path, device),
        Some("safetensors") => load_safetensors_file(path, device),
        _ => Err(voirs_sdk::VoirsError::UnsupportedFileFormat {
            path: path.to_path_buf(),
            format: extension.unwrap_or("unknown").to_string(),
        }),
    }
}
fn load_numpy_file(path: &Path, device: &Device) -> Result<Tensor> {
let data = std::fs::read(path).map_err(|e| voirs_sdk::VoirsError::IoError {
path: path.to_path_buf(),
operation: voirs_sdk::error::IoOperation::Read,
source: e,
})?;
if data.len() < 10 || &data[0..6] != b"\x93NUMPY" {
return Err(voirs_sdk::VoirsError::config_error(
"Invalid NumPy file: magic number mismatch",
));
}
let major_version = data[6];
let minor_version = data[7];
if major_version != 1 && major_version != 2 {
return Err(voirs_sdk::VoirsError::config_error(format!(
"Unsupported NumPy version: {}.{}",
major_version, minor_version
)));
}
let header_len = if major_version == 1 {
u16::from_le_bytes([data[8], data[9]]) as usize
} else {
u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
};
let header_start = if major_version == 1 { 10 } else { 12 };
let header_end = header_start + header_len;
if data.len() < header_end {
return Err(voirs_sdk::VoirsError::config_error(
"Invalid NumPy file: truncated header",
));
}
let header_str = std::str::from_utf8(&data[header_start..header_end])
.map_err(|_| voirs_sdk::VoirsError::config_error("Invalid NumPy header: not UTF-8"))?;
let shape = parse_numpy_shape(header_str)?;
let dtype = parse_numpy_dtype(header_str)?;
if dtype != "f4" && dtype != "<f4" && dtype != "float32" {
return Err(voirs_sdk::VoirsError::config_error(format!(
"Unsupported NumPy dtype: {}. Only float32 is supported.",
dtype
)));
}
let data_start = header_end;
let num_elements: usize = shape.iter().product();
let expected_bytes = num_elements * 4;
if data.len() < data_start + expected_bytes {
return Err(voirs_sdk::VoirsError::config_error(
"Invalid NumPy file: insufficient data",
));
}
let f32_data: Vec<f32> = data[data_start..data_start + expected_bytes]
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
let tensor = Tensor::from_vec(f32_data, shape.as_slice(), device).map_err(|e| {
voirs_sdk::VoirsError::config_error(format!(
"Failed to create tensor from NumPy data: {}",
e
))
})?;
Ok(tensor)
}
/// Extract the `shape` tuple from a NumPy header dict literal, e.g.
/// `{'descr': '<f4', 'fortran_order': False, 'shape': (80, 100), }`.
///
/// An empty tuple (0-d array) is reported as `vec![1]` — one element.
fn parse_numpy_shape(header: &str) -> Result<Vec<usize>> {
    let malformed = || voirs_sdk::VoirsError::config_error("NumPy shape malformed");
    let shape_start = header
        .find("'shape':")
        .or_else(|| header.find("\"shape\":"))
        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy header missing 'shape' field"))?;
    let rest = &header[shape_start..];
    let open = rest.find('(').ok_or_else(malformed)?;
    let close = rest.find(')').ok_or_else(malformed)?;
    let tuple_content = &rest[open + 1..close];
    if tuple_content.trim().is_empty() {
        // Scalar (0-d) array.
        return Ok(vec![1]);
    }
    // Trailing commas like "(80,)" yield an empty final part; skip them.
    tuple_content
        .split(',')
        .filter(|part| !part.trim().is_empty())
        .map(|part| {
            part.trim().parse::<usize>().map_err(|_| {
                voirs_sdk::VoirsError::config_error(format!("Invalid dimension: {}", part))
            })
        })
        .collect()
}
/// Extract the `descr` (dtype) value from a NumPy header dict literal,
/// e.g. `"<f4"` for little-endian float32.
///
/// # Errors
/// Returns a configuration error when the `descr` key is missing or its
/// quoted value cannot be located.
fn parse_numpy_dtype(header: &str) -> Result<String> {
    let descr_start = header
        .find("'descr':")
        .or_else(|| header.find("\"descr\":"))
        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy header missing 'descr' field"))?;
    // BUG FIX: the quote search previously began at `descr_start`, which is
    // the opening quote of the key itself, so the function always returned
    // the literal string "descr" instead of the dtype value (and every .npy
    // load then failed the float32 check). Skip past the key first — both
    // quoting styles (`'descr':` / `"descr":`) are exactly 8 bytes.
    let after_key = &header[descr_start + "'descr':".len()..];
    let value_start = after_key
        .find(|c: char| c == '\'' || c == '"')
        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy descr malformed"))?;
    let value_str = &after_key[value_start + 1..];
    let value_end = value_str
        .find(|c: char| c == '\'' || c == '"')
        .ok_or_else(|| voirs_sdk::VoirsError::config_error("NumPy descr malformed"))?;
    Ok(value_str[..value_end].to_string())
}
/// Stub loader for PyTorch `.pt`/`.pth` files.
///
/// Native PyTorch deserialization is not implemented; this always returns a
/// configuration error whose message lists conversion alternatives (NumPy,
/// SafeTensors, ONNX) for the user.
fn load_pytorch_file(path: &Path, _device: &Device) -> Result<Tensor> {
    Err(voirs_sdk::VoirsError::config_error(format!(
        "PyTorch .pt file loading requires Python interop or conversion.\n\
        \n\
        Alternatives:\n\
        1. Convert to NumPy: python -c \"import torch, numpy as np; np.save('output.npy', torch.load('{}').numpy())\"\n\
        2. Convert to SafeTensors: Use safetensors.torch.save_file() in Python\n\
        3. Use ONNX format: Export model to ONNX and use --input-format onnx\n\
        \n\
        For native PyTorch support, compile with 'tch-rs' feature (requires libtorch).",
        path.display()
    )))
}
fn load_safetensors_file(path: &Path, device: &Device) -> Result<Tensor> {
use safetensors::SafeTensors;
let data = std::fs::read(path).map_err(|e| voirs_sdk::VoirsError::IoError {
path: path.to_path_buf(),
operation: voirs_sdk::error::IoOperation::Read,
source: e,
})?;
let tensors = SafeTensors::deserialize(&data).map_err(|e| {
voirs_sdk::VoirsError::config_error(format!("Failed to load SafeTensors: {}", e))
})?;
let names = tensors.names();
let tensor_name = names
.first()
.ok_or_else(|| voirs_sdk::VoirsError::config_error("No tensors found in file"))?;
let tensor_view = tensors
.tensor(tensor_name)
.map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to get tensor: {}", e)))?;
let shape: Vec<usize> = tensor_view.shape().to_vec();
let data = tensor_view.data();
let f32_data: Vec<f32> = data
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
let tensor = Tensor::from_vec(f32_data, shape.as_slice(), device).map_err(|e| {
voirs_sdk::VoirsError::config_error(format!("Failed to create tensor: {}", e))
})?;
Ok(tensor)
}
/// Synthesize a random mel spectrogram of shape (1, 80, 100) with values
/// uniform in [-1, 1), used to smoke-test the vocoder without real input.
fn generate_dummy_mel_spectrogram(device: &Device) -> Result<Tensor> {
    const BATCH: usize = 1;
    const MEL_CHANNELS: usize = 80;
    const TIME_FRAMES: usize = 100;
    let total = BATCH * MEL_CHANNELS * TIME_FRAMES;
    let mut values = Vec::with_capacity(total);
    for _ in 0..total {
        // fastrand::f32() is uniform in [0, 1); rescale to [-1, 1).
        values.push(fastrand::f32() * 2.0 - 1.0);
    }
    Tensor::from_vec(values, (BATCH, MEL_CHANNELS, TIME_FRAMES), device)
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Failed to create tensor: {}", e)))
}
fn save_audio_tensor(tensor: &Tensor, output: &Path, sample_rate: u32) -> Result<()> {
use hound::{WavSpec, WavWriter};
let audio_data: Vec<f32> = tensor
.flatten_all()
.map_err(|e| {
voirs_sdk::VoirsError::config_error(format!("Failed to flatten tensor: {}", e))
})?
.to_vec1()
.map_err(|e| {
voirs_sdk::VoirsError::config_error(format!("Failed to convert tensor to vec: {}", e))
})?;
let spec = WavSpec {
channels: 1,
sample_rate,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};
let mut writer =
WavWriter::create(output, spec).map_err(|e| voirs_sdk::VoirsError::IoError {
path: output.to_path_buf(),
operation: voirs_sdk::error::IoOperation::Write,
source: std::io::Error::other(e),
})?;
for &sample in &audio_data {
let sample_i16 = (sample * 32767.0).clamp(-32768.0, 32767.0) as i16;
writer
.write_sample(sample_i16)
.map_err(|e| voirs_sdk::VoirsError::IoError {
path: output.to_path_buf(),
operation: voirs_sdk::error::IoOperation::Write,
source: std::io::Error::other(e),
})?;
}
writer
.finalize()
.map_err(|e| voirs_sdk::VoirsError::IoError {
path: output.to_path_buf(),
operation: voirs_sdk::error::IoOperation::Write,
source: std::io::Error::other(e),
})?;
Ok(())
}
/// Vocode every supported mel file (`.npy`, `.safetensors`, `.pt`, `.pth`)
/// in `input_dir`, writing one WAV per input into `output_dir`.
///
/// Individual file failures are reported and counted but do not abort the
/// batch; a summary box is printed at the end. `steps` is overridden by a
/// `quality` preset when one is given.
async fn run_batch_inference(
    checkpoint: &Path,
    input_dir: &Path,
    output_dir: &Path,
    mut steps: usize,
    quality: Option<&str>,
    show_metrics: bool,
    global: &GlobalOptions,
) -> Result<()> {
    use std::time::Instant;
    // Quality preset takes precedence over --steps (same rule as single mode).
    if let Some(quality_str) = quality {
        let preset = QualityPreset::from_str(quality_str)?;
        steps = preset.steps();
    }
    if !global.quiet {
        println!("🎵 VoiRS Batch Vocoder Inference");
        println!("═══════════════════════════════════════");
        println!("Checkpoint: {}", checkpoint.display());
        println!("Input dir: {}", input_dir.display());
        println!("Output dir: {}", output_dir.display());
        println!("Steps: {}", steps);
        if let Some(q) = quality {
            println!("Quality: {}", q);
        }
        println!("═══════════════════════════════════════\n");
    }
    if !input_dir.is_dir() {
        return Err(voirs_sdk::VoirsError::config_error(format!(
            "Input directory not found: {}",
            input_dir.display()
        )));
    }
    // `?` here relies on a From<std::io::Error> conversion on VoirsError
    // (unlike the explicit IoError mapping used below).
    std::fs::create_dir_all(output_dir)?;
    // Collect candidate mel files by extension; unreadable entries are skipped.
    let mel_files: Vec<_> = std::fs::read_dir(input_dir)
        .map_err(|e| voirs_sdk::VoirsError::IoError {
            path: input_dir.to_path_buf(),
            operation: voirs_sdk::error::IoOperation::Read,
            source: e,
        })?
        .filter_map(|entry| entry.ok())
        .map(|entry| entry.path())
        .filter(|path| {
            path.extension()
                .and_then(|e| e.to_str())
                .map(|ext| matches!(ext, "npy" | "safetensors" | "pt" | "pth"))
                .unwrap_or(false)
        })
        .collect();
    if mel_files.is_empty() {
        return Err(voirs_sdk::VoirsError::config_error(
            "No mel spectrogram files found in input directory",
        ));
    }
    if !global.quiet {
        println!("Found {} mel spectrogram files", mel_files.len());
        println!();
    }
    // Device selection mirrors single-file mode (CUDA if available, else CPU).
    let device = if global.gpu {
        #[cfg(feature = "cuda")]
        {
            Device::new_cuda(0).unwrap_or(Device::Cpu)
        }
        #[cfg(not(feature = "cuda"))]
        {
            Device::Cpu
        }
    } else {
        Device::Cpu
    };
    // Load the model once and reuse it for every file in the batch.
    let model = DiffWave::load_from_safetensors(checkpoint, device.clone())?;
    // total_time sums per-file durations; total_elapsed (below) is wall-clock.
    let mut total_time = 0.0;
    let mut successful = 0;
    let mut failed = 0;
    let batch_start = Instant::now();
    for (idx, mel_file) in mel_files.iter().enumerate() {
        let file_start = Instant::now();
        let output_name = mel_file
            .file_stem()
            .and_then(|n| n.to_str())
            .unwrap_or("output");
        let output_path = output_dir.join(format!("{}.wav", output_name));
        if !global.quiet {
            println!(
                "[{}/{}] Processing {}...",
                idx + 1,
                mel_files.len(),
                mel_file.display()
            );
        }
        let result =
            process_single_mel(&model, mel_file, &output_path, steps, &device, global).await;
        let file_time = file_start.elapsed().as_secs_f64();
        total_time += file_time;
        // A per-file failure is logged to stderr but the batch continues.
        match result {
            Ok(_) => {
                successful += 1;
                if !global.quiet {
                    println!(" ✓ Complete in {:.2}s", file_time);
                }
            }
            Err(e) => {
                failed += 1;
                eprintln!(" ✗ Failed: {}", e);
            }
        }
    }
    let total_elapsed = batch_start.elapsed().as_secs_f64();
    // Summary prints when not quiet OR when metrics were explicitly requested.
    if !global.quiet || show_metrics {
        println!();
        println!("╔═══════════════════════════════════════╗");
        println!("║ Batch Inference Complete ║");
        println!("╠═══════════════════════════════════════╣");
        println!("║ Total files: {:<21} ║", mel_files.len());
        println!("║ Successful: {:<21} ║", successful);
        println!("║ Failed: {:<21} ║", failed);
        println!("║ Total time: {:<18.2}s ║", total_elapsed);
        println!(
            "║ Avg time/file: {:<18.2}s ║",
            total_time / mel_files.len() as f64
        );
        if successful > 0 {
            println!(
                "║ Throughput: {:<18.2}/s ║",
                successful as f64 / total_elapsed
            );
        }
        println!("╚═══════════════════════════════════════╝");
    }
    Ok(())
}
/// Vocode one mel file to a WAV: load the mel, run deterministic DDIM
/// inference (eta = 0.0), and save the result at 22050 Hz.
async fn process_single_mel(
    model: &DiffWave,
    mel_path: &Path,
    output_path: &Path,
    steps: usize,
    device: &Device,
    _global: &GlobalOptions,
) -> Result<()> {
    let mel = load_mel_spectrogram(mel_path, device)?;
    let audio = model
        .inference(&mel, SamplingMethod::DDIM { steps, eta: 0.0 })
        .map_err(|e| voirs_sdk::VoirsError::config_error(format!("Inference failed: {}", e)))?;
    save_audio_tensor(&audio, output_path, 22050)
}