use std::path::Path;
use anyhow::{Context, Result};
use crate::codec::{NeuCodecDecoder, NeuCodecEncoder, SAMPLE_RATE};
use crate::npy;
#[cfg(feature = "backbone")]
use crate::tokens;
#[cfg(all(feature = "backbone", feature = "espeak"))]
use crate::phonemize;
#[cfg(feature = "backbone")]
use crate::backbone::{BackboneModel, DEFAULT_N_CTX};
/// Generation settings passed to the backbone when producing speech tokens.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Token limit forwarded to the backbone's `generate` call.
    pub max_new_tokens: u32,
}

impl GenerationConfig {
    /// Token limit used by the `Default` implementation.
    const DEFAULT_MAX_NEW_TOKENS: u32 = 2048;
}

impl Default for GenerationConfig {
    fn default() -> Self {
        GenerationConfig {
            max_new_tokens: Self::DEFAULT_MAX_NEW_TOKENS,
        }
    }
}
/// Text-to-speech engine.
///
/// It pairs a token-generating backbone (available only with the `backbone`
/// feature) with the NeuCodec decoder that turns speech token ids into audio
/// samples.
pub struct NeuTTS {
    /// Model that generates speech tokens from a text/phoneme prompt.
    #[cfg(feature = "backbone")]
    pub backbone: BackboneModel,
    /// Decoder that converts speech token ids into `f32` audio samples.
    pub codec: NeuCodecDecoder,
    /// Language code handed to the phonemiser (for example `"en-us"`).
    pub language: String,
    /// Generation settings forwarded to the backbone.
    pub config: GenerationConfig,
}
impl NeuTTS {
    /// Decoder weights loaded by [`NeuTTS::load`].
    ///
    /// The path is relative. It is resolved against the process's current
    /// working directory, not against the backbone's location.
    pub const DEFAULT_DECODER_PATH: &'static str = "models/neucodec_decoder.safetensors";

    /// Loads the backbone and the NeuCodec decoder from explicit paths.
    ///
    /// The backbone is loaded with `DEFAULT_N_CTX` as its context size. The
    /// generation config starts at `GenerationConfig::default()`. Progress
    /// messages are written to stderr.
    ///
    /// # Errors
    ///
    /// Returns an error if either model fails to load.
    #[cfg(feature = "backbone")]
    pub fn load_with_decoder(
        backbone_path: &Path,
        decoder_path: &Path,
        language: &str,
    ) -> Result<Self> {
        eprintln!("[neutts] Loading backbone: {}", backbone_path.display());
        let backbone = BackboneModel::load(backbone_path, DEFAULT_N_CTX)
            .context("Failed to load backbone")?;

        eprintln!("[neutts] Loading NeuCodec decoder: {}", decoder_path.display());
        let codec = NeuCodecDecoder::from_file(decoder_path).with_context(|| {
            format!("Failed to load NeuCodec decoder from {}", decoder_path.display())
        })?;

        Ok(Self {
            backbone,
            codec,
            language: language.to_string(),
            config: GenerationConfig::default(),
        })
    }

    /// Loads the backbone and the decoder found at [`Self::DEFAULT_DECODER_PATH`].
    ///
    /// # Errors
    ///
    /// Same as [`NeuTTS::load_with_decoder`].
    #[cfg(feature = "backbone")]
    pub fn load(backbone_path: &Path, language: &str) -> Result<Self> {
        Self::load_with_decoder(backbone_path, Path::new(Self::DEFAULT_DECODER_PATH), language)
    }

    /// Builds a decoder-only instance with language `"en-us"` and the default
    /// generation config. Used when the crate is built without a backbone.
    ///
    /// # Errors
    ///
    /// Returns an error if the NeuCodec decoder cannot be initialised.
    #[cfg(not(feature = "backbone"))]
    pub fn load_codec_only() -> Result<Self> {
        let codec = NeuCodecDecoder::new()
            .context("Failed to initialise NeuCodec Burn decoder")?;
        Ok(Self {
            codec,
            language: "en-us".to_string(),
            config: GenerationConfig::default(),
        })
    }

    /// Reads reference speech codes from an `.npy` file as `i32` values.
    ///
    /// # Errors
    ///
    /// Returns an error (including the path) if the file cannot be loaded.
    pub fn load_ref_codes(&self, path: &Path) -> Result<Vec<i32>> {
        npy::load_npy_i32(path)
            .with_context(|| format!("Failed to load reference codes: {}", path.display()))
    }

    /// Parses reference speech codes from an in-memory NPY buffer and converts
    /// them to `i32`.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer is not valid NPY or cannot be converted.
    pub fn load_ref_codes_from_bytes(&self, bytes: &[u8]) -> Result<Vec<i32>> {
        npy::parse_npy(bytes)
            .context("Failed to parse embedded NPY reference codes")?
            .into_i32()
            .context("Failed to convert embedded NPY to i32")
    }

    /// Encodes a reference WAV file into speech codes using `encoder`.
    ///
    /// # Errors
    ///
    /// Returns an error (including the path) if encoding fails.
    pub fn encode_reference(&self, wav_path: &Path, encoder: &NeuCodecEncoder) -> Result<Vec<i32>> {
        encoder
            .encode_wav(wav_path)
            .with_context(|| format!("Failed to encode reference audio: {}", wav_path.display()))
    }

    /// Writes reference speech codes to an `.npy` file.
    ///
    /// # Errors
    ///
    /// Returns an error (including the path) if the file cannot be written.
    pub fn save_ref_codes(&self, codes: &[i32], path: &Path) -> Result<()> {
        npy::write_npy_i32(path, codes)
            .with_context(|| format!("Failed to save reference codes: {}", path.display()))
    }

    /// Synthesises `text` in the voice described by `ref_codes`/`ref_text`.
    ///
    /// Both texts are phonemised with `self.language`, then passed to
    /// [`NeuTTS::infer_from_ipa`].
    ///
    /// # Errors
    ///
    /// Returns an error if phonemisation, generation or decoding fails.
    #[cfg(all(feature = "backbone", feature = "espeak"))]
    pub fn infer(&self, text: &str, ref_codes: &[i32], ref_text: &str) -> Result<Vec<f32>> {
        let ref_phones = phonemize::phonemize(ref_text, &self.language)
            .context("Phonemisation of ref_text failed")?;
        let input_phones = phonemize::phonemize(text, &self.language)
            .context("Phonemisation of input text failed")?;
        self.infer_from_ipa(&input_phones, ref_codes, &ref_phones)
    }

    /// Synthesises audio from already-phonemised (IPA) input.
    ///
    /// This builds the prompt, runs the backbone for at most
    /// `self.config.max_new_tokens`, extracts the speech token ids from its
    /// output and decodes them to samples.
    ///
    /// # Errors
    ///
    /// Returns an error if generation fails, if the output contains no speech
    /// tokens, or if decoding fails.
    #[cfg(feature = "backbone")]
    pub fn infer_from_ipa(
        &self,
        input_ipa: &str,
        ref_codes: &[i32],
        ref_ipa: &str,
    ) -> Result<Vec<f32>> {
        const MAX_SNIPPET_CHARS: usize = 200;

        let prompt = tokens::build_prompt(ref_ipa, input_ipa, ref_codes);
        let generated = self
            .backbone
            .generate(&prompt, self.config.max_new_tokens)
            .context("Backbone generation failed")?;

        let speech_ids = tokens::extract_ids(&generated);
        if speech_ids.is_empty() {
            // Cut on a char boundary. Byte-slicing text that contains multi-byte
            // UTF-8 (e.g. IPA symbols) at a fixed offset would panic here.
            let snippet_end = generated
                .char_indices()
                .nth(MAX_SNIPPET_CHARS)
                .map_or(generated.len(), |(idx, _)| idx);
            anyhow::bail!(
                "No speech tokens found in backbone output.\n\
                 Prompt may have exceeded the context window, or the model produced no output.\n\
                 Generated text snippet: {:?}",
                &generated[..snippet_end]
            );
        }
        self.decode_tokens(&speech_ids)
    }

    /// Decodes speech token ids into audio samples with the NeuCodec decoder.
    ///
    /// # Errors
    ///
    /// Returns an error if the decoder fails.
    pub fn decode_tokens(&self, speech_ids: &[i32]) -> Result<Vec<f32>> {
        self.codec.decode(speech_ids).context("NeuCodec Burn decode failed")
    }

    /// Returns `(peak, scale)` for converting `audio` to PCM.
    ///
    /// `peak` is the largest absolute sample. `scale` attenuates the signal to
    /// fit `[-1, 1]` only when `peak > 1.0`. It never amplifies.
    fn pcm16_scale(audio: &[f32]) -> (f32, f32) {
        let peak = audio.iter().map(|&s| s.abs()).fold(0.0f32, f32::max);
        let scale = if peak > 1.0 { 1.0 / peak } else { 1.0 };
        (peak, scale)
    }

    /// Converts one scaled `f32` sample to a saturated 16-bit PCM value.
    fn to_pcm16(sample: f32, scale: f32) -> i16 {
        (sample * scale * i16::MAX as f32).clamp(i16::MIN as f32, i16::MAX as f32) as i16
    }

    /// Writes `audio` as a 16-bit mono PCM WAV file at `SAMPLE_RATE`.
    ///
    /// Samples are attenuated by `1 / peak` only when the peak magnitude
    /// exceeds 1.0. A one-line summary is printed to stdout.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, written or finalised.
    pub fn write_wav(&self, audio: &[f32], output_path: &Path) -> Result<()> {
        let (peak, scale) = Self::pcm16_scale(audio);
        let spec = hound::WavSpec {
            channels: 1,
            sample_rate: SAMPLE_RATE,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = hound::WavWriter::create(output_path, spec)
            .with_context(|| format!("Cannot create WAV: {}", output_path.display()))?;
        for &s in audio {
            writer
                .write_sample(Self::to_pcm16(s, scale))
                .context("WAV write error")?;
        }
        writer.finalize().context("WAV finalise error")?;
        println!(
            "Saved {} samples ({:.2} s) to {} [peak={peak:.4}, scale={scale:.4}]",
            audio.len(),
            audio.len() as f32 / SAMPLE_RATE as f32,
            output_path.display()
        );
        Ok(())
    }

    /// Serialises `audio` into an in-memory 16-bit mono PCM WAV file.
    ///
    /// The output is a canonical 44-byte RIFF header followed by
    /// little-endian samples. Scaling matches [`NeuTTS::write_wav`].
    ///
    /// RIFF size fields are 32-bit. A payload larger than that limit gets the
    /// size fields clamped to `u32::MAX` rather than wrapped.
    pub fn to_wav_bytes(&self, audio: &[f32]) -> Vec<u8> {
        const HEADER_LEN: usize = 44;
        let (_, scale) = Self::pcm16_scale(audio);

        let num_channels: u16 = 1;
        let bits_per_sample: u16 = 16;
        let sample_rate: u32 = SAMPLE_RATE;
        let byte_rate: u32 = sample_rate * num_channels as u32 * bits_per_sample as u32 / 8;
        let block_align: u16 = num_channels * bits_per_sample / 8;

        let payload_len = audio.len() * 2;
        // Clamp instead of truncating with `as` or overflowing `36 + data_size`.
        let data_size = payload_len.min(u32::MAX as usize) as u32;
        let riff_size = data_size.saturating_add(36);

        let mut buf = Vec::with_capacity(HEADER_LEN + payload_len);
        // RIFF chunk descriptor.
        buf.extend_from_slice(b"RIFF");
        buf.extend_from_slice(&riff_size.to_le_bytes());
        buf.extend_from_slice(b"WAVE");
        // "fmt " sub-chunk: 16-byte PCM format block.
        buf.extend_from_slice(b"fmt ");
        buf.extend_from_slice(&16u32.to_le_bytes());
        buf.extend_from_slice(&1u16.to_le_bytes()); // audio format 1 = integer PCM
        buf.extend_from_slice(&num_channels.to_le_bytes());
        buf.extend_from_slice(&sample_rate.to_le_bytes());
        buf.extend_from_slice(&byte_rate.to_le_bytes());
        buf.extend_from_slice(&block_align.to_le_bytes());
        buf.extend_from_slice(&bits_per_sample.to_le_bytes());
        // "data" sub-chunk.
        buf.extend_from_slice(b"data");
        buf.extend_from_slice(&data_size.to_le_bytes());
        for &s in audio {
            buf.extend_from_slice(&Self::to_pcm16(s, scale).to_le_bytes());
        }
        buf
    }

    /// Runs [`NeuTTS::infer`] and writes the result with [`NeuTTS::write_wav`].
    ///
    /// # Errors
    ///
    /// Propagates errors from inference or WAV writing.
    #[cfg(all(feature = "backbone", feature = "espeak"))]
    pub fn infer_to_file(
        &self,
        text: &str,
        ref_codes: &[i32],
        ref_text: &str,
        output_path: &Path,
    ) -> Result<()> {
        let audio = self.infer(text, ref_codes, ref_text)?;
        self.write_wav(&audio, output_path)
    }
}