1use std::path::Path;
19
20use anyhow::{Context, Result};
21
22use crate::decoder::NeuCodecDecoder;
23use crate::tokens;
24
25#[cfg(feature = "llama")]
26use crate::backbone::{BackboneModel, DEFAULT_N_CTX};
27
/// Tunable parameters for a single generation run.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Token budget handed to the backbone's `generate` call.
    pub max_new_tokens: u32,
}

impl Default for GenerationConfig {
    /// Returns a configuration with a budget of 2048 new tokens.
    fn default() -> Self {
        let max_new_tokens = 2_048;
        GenerationConfig { max_new_tokens }
    }
}
41
42pub struct NeuTTS {
44 #[cfg(feature = "llama")]
45 pub backbone: BackboneModel,
46 pub codec: NeuCodecDecoder,
47 pub language: String,
48 pub config: GenerationConfig,
49}
50
51impl NeuTTS {
52 #[cfg(feature = "llama")]
53 pub fn load_with_decoder(
54 backbone_path: &Path,
55 decoder_path: &Path,
56 language: &str,
57 ) -> Result<Self> {
58 eprintln!("[neutts] Loading backbone: {}", backbone_path.display());
59 let backbone =
60 BackboneModel::load(backbone_path, DEFAULT_N_CTX).context("Failed to load backbone")?;
61
62 eprintln!(
63 "[neutts] Loading NeuCodec decoder: {}",
64 decoder_path.display()
65 );
66 let codec = NeuCodecDecoder::from_file(decoder_path).with_context(|| {
67 format!(
68 "Failed to load NeuCodec decoder from {}",
69 decoder_path.display()
70 )
71 })?;
72
73 Ok(Self {
74 backbone,
75 codec,
76 language: language.to_string(),
77 config: GenerationConfig::default(),
78 })
79 }
80
81 #[cfg(feature = "llama")]
82 pub fn load(backbone_path: &Path, language: &str) -> Result<Self> {
83 let decoder_path = crate::decoder::decoder_weights_path()?;
84 Self::load_with_decoder(backbone_path, &decoder_path, language)
85 }
86
87 #[cfg(not(feature = "llama"))]
88 pub fn load_codec_only() -> Result<Self> {
89 let codec = NeuCodecDecoder::new().context("Failed to initialise NeuCodec decoder")?;
90 Ok(Self {
91 codec,
92 language: "en-us".to_string(),
93 config: GenerationConfig::default(),
94 })
95 }
96
97 #[cfg(feature = "llama")]
98 pub fn infer_from_ipa(
99 &self,
100 input_ipa: &str,
101 ref_codes: &[i32],
102 ref_ipa: &str,
103 ) -> Result<Vec<f32>> {
104 let prompt = tokens::build_prompt(ref_ipa, input_ipa, ref_codes);
105 let generated = self
106 .backbone
107 .generate(&prompt, self.config.max_new_tokens)
108 .context("Backbone generation failed")?;
109
110 let speech_ids = tokens::extract_ids(&generated);
111 if speech_ids.is_empty() {
112 anyhow::bail!(
113 "No speech tokens in backbone output. Snippet: {:?}",
114 &generated[..generated.len().min(200)]
115 );
116 }
117
118 self.codec
119 .decode(&speech_ids)
120 .context("NeuCodec decode failed")
121 }
122
123 pub fn decode_tokens(&self, speech_ids: &[i32]) -> Result<Vec<f32>> {
124 self.codec
125 .decode(speech_ids)
126 .context("NeuCodec decode failed")
127 }
128}