#[cfg(feature = "audio-loading")]
use crate::model::voice_clone::VoiceClonePromptItem;
use anyhow::{Context, Result, bail};
use candle_core::DType;
#[cfg(feature = "audio-loading")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "audio-loading")]
use std::fs;
use std::path::PathBuf;
/// Format version of the voice prompt file.
///
/// `save_voice_prompt` writes this value into every file it produces, and
/// `load_voice_prompt` rejects any file whose stored version differs from it.
#[cfg(feature = "audio-loading")]
const VOICE_PROMPT_VERSION: u32 = 1;
/// JSON document written by `save_voice_prompt` and read back by `load_voice_prompt`.
///
/// It is serialized with `serde_json` as pretty-printed JSON; the tensor payload is carried
/// as a base64 string inside it.
#[cfg(feature = "audio-loading")]
#[derive(Debug, Clone, Serialize, Deserialize)]
struct VoicePromptFile {
/// Format version; must equal `VOICE_PROMPT_VERSION` for the file to load.
version: u32,
/// Flag supplied by the caller when saving; restored as `x_vector_only_mode` on load.
x_vector_only: bool,
/// Optional reference text supplied when saving; restored verbatim on load.
ref_text: Option<String>,
/// Base64 (standard alphabet) encoding of a safetensors buffer holding the
/// `ref_spk_embedding` tensor and, when present, the `ref_code` tensor.
prompt_data: String,
}
/// Serializes a voice-clone prompt to a JSON file at `path`.
///
/// The tensors are packed into an in-memory safetensors buffer (`ref_spk_embedding` as F32
/// and, when present, `ref_code` as I64). That buffer is base64-encoded and embedded in a
/// pretty-printed JSON document together with the format version, `x_vector_only`, and
/// `ref_text`.
///
/// Tensors of any rank are accepted: the data is flattened for extraction while the
/// original dimensions are recorded as the tensor shape in the safetensors buffer.
///
/// # Errors
///
/// Returns an error if a tensor cannot be converted or extracted, if safetensors or JSON
/// serialization fails, or if the file cannot be written.
#[cfg(feature = "audio-loading")]
pub fn save_voice_prompt(
    prompt: &VoiceClonePromptItem,
    path: &PathBuf,
    x_vector_only: bool,
    ref_text: Option<&str>,
) -> Result<()> {
    use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
    use safetensors::serialize;
    use safetensors::tensor::TensorView;
    use std::collections::HashMap;

    // Flatten before extracting so the embedding is not restricted to rank 1; the real
    // shape is kept separately and handed to safetensors below.
    let spk_shape: Vec<usize> = prompt.ref_spk_embedding.dims().to_vec();
    let spk_data = prompt
        .ref_spk_embedding
        .to_dtype(candle_core::DType::F32)?
        .flatten_all()?
        .to_vec1::<f32>()?;

    // Optional reference codes, extracted the same way. The owned buffers are created
    // before `views` because the tensor views borrow from them.
    let code = prompt
        .ref_code
        .as_ref()
        .map(|code| -> Result<(Vec<i64>, Vec<usize>)> {
            let data = code
                .to_dtype(candle_core::DType::I64)?
                .flatten_all()?
                .to_vec1::<i64>()?;
            Ok((data, code.dims().to_vec()))
        })
        .transpose()?;

    let mut views: HashMap<String, TensorView<'_>> = HashMap::new();
    views.insert(
        "ref_spk_embedding".to_string(),
        TensorView::new(
            safetensors::Dtype::F32,
            spk_shape,
            bytemuck::cast_slice(&spk_data),
        )?,
    );
    if let Some((data, shape)) = &code {
        views.insert(
            "ref_code".to_string(),
            TensorView::new(
                safetensors::Dtype::I64,
                shape.clone(),
                bytemuck::cast_slice(data),
            )?,
        );
    }

    let buffer = serialize(&views, None).context("Failed to serialize voice prompt tensors")?;
    let prompt_file = VoicePromptFile {
        version: VOICE_PROMPT_VERSION,
        x_vector_only,
        ref_text: ref_text.map(str::to_owned),
        prompt_data: BASE64.encode(&buffer),
    };
    let json = serde_json::to_string_pretty(&prompt_file)
        .context("Failed to encode voice prompt file as JSON")?;
    fs::write(path, json)
        .with_context(|| format!("Failed to write voice prompt file: {:?}", path))?;
    Ok(())
}
/// Loads a voice-clone prompt from the JSON file at `path`.
///
/// The file is parsed as a `VoicePromptFile`, its version is checked against
/// `VOICE_PROMPT_VERSION`, and the base64 payload is decoded and loaded as safetensors
/// onto `device`.
///
/// * `ref_spk_embedding` is required and is converted to `dtype`.
/// * `ref_code` is optional and is converted to `U32` when present.
/// * `icl_mode` is set only when reference codes and reference text are both present and
///   the file is not flagged as x-vector-only.
///
/// # Errors
///
/// Returns an error if the file cannot be read or parsed, if the version does not match,
/// if the payload is not valid base64 or safetensors data, if `ref_spk_embedding` is
/// missing, or if a dtype conversion fails.
#[cfg(feature = "audio-loading")]
pub fn load_voice_prompt(
    path: &PathBuf,
    device: &candle_core::Device,
    dtype: candle_core::DType,
) -> Result<VoiceClonePromptItem> {
    use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
    use candle_core::safetensors::load_buffer;

    let content = fs::read_to_string(path)
        .with_context(|| format!("Failed to read voice prompt file: {:?}", path))?;
    let prompt_file: VoicePromptFile = serde_json::from_str(&content)
        .with_context(|| format!("Failed to parse voice prompt file: {:?}", path))?;

    if prompt_file.version != VOICE_PROMPT_VERSION {
        bail!(
            "Voice prompt file version mismatch: expected {}, got {}",
            VOICE_PROMPT_VERSION,
            prompt_file.version
        );
    }

    let buffer = BASE64
        .decode(&prompt_file.prompt_data)
        .context("Failed to decode voice prompt data")?;
    // Attach context here as well, so a corrupt payload is reported like the other stages.
    let tensors =
        load_buffer(&buffer, device).context("Failed to load voice prompt tensors")?;

    let ref_spk_embedding = tensors
        .get("ref_spk_embedding")
        .context("Voice prompt file missing ref_spk_embedding")?
        .to_dtype(dtype)?;
    let ref_code = tensors
        .get("ref_code")
        .map(|t| t.to_dtype(candle_core::DType::U32))
        .transpose()?;

    let icl_mode =
        ref_code.is_some() && prompt_file.ref_text.is_some() && !prompt_file.x_vector_only;

    Ok(VoiceClonePromptItem {
        ref_code,
        ref_spk_embedding,
        x_vector_only_mode: prompt_file.x_vector_only,
        icl_mode,
        ref_text: prompt_file.ref_text,
    })
}
pub fn write_wav(path: &PathBuf, audio: &candle_core::Tensor, sample_rate: usize) -> Result<()> {
use hound::{SampleFormat, WavSpec, WavWriter};
let audio_flat = audio.flatten_all()?;
let num_samples = audio_flat.dim(0)?;
if num_samples == 0 {
bail!(
"No audio was generated. The model produced 0 samples. \
This may indicate an issue with the model weights, configuration, \
or that the input text is too short/empty."
);
}
let spec = WavSpec {
channels: 1,
sample_rate: sample_rate as u32,
bits_per_sample: 16,
sample_format: SampleFormat::Int,
};
let mut writer = WavWriter::create(path, spec).context("Failed to create WAV file")?;
let samples = audio_flat.to_dtype(DType::F32)?.to_vec1::<f32>()?;
for sample in samples {
let sample = sample.clamp(-1.0, 1.0);
let sample_i16 = (sample * 32767.0) as i16;
writer.write_sample(sample_i16)?;
}
writer.finalize()?;
Ok(())
}