1use std::path::Path;
19
20use anyhow::{Context, Result};
21use rlx_runtime::Device;
22
23use crate::decoder::NeuCodecDecoder;
24use crate::tokens;
25
26#[cfg(feature = "llama")]
27use crate::backbone::{BackboneModel, DEFAULT_N_CTX};
28
/// Settings that control how the backbone generates speech tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GenerationConfig {
    /// Maximum number of new tokens requested from the backbone's `generate` call.
    pub max_new_tokens: u32,
}

impl GenerationConfig {
    /// Value used for [`GenerationConfig::max_new_tokens`] by [`Default::default`].
    pub const DEFAULT_MAX_NEW_TOKENS: u32 = 2048;
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            max_new_tokens: Self::DEFAULT_MAX_NEW_TOKENS,
        }
    }
}
42
/// Text-to-speech pipeline state: an optional backbone model (present only
/// with the `llama` feature), a NeuCodec decoder, a language tag and the
/// generation settings.
// NOTE(review): fields are dropped in declaration order (backbone before
// codec); keep the order unless that is confirmed not to matter.
pub struct NeuTTS {
    /// Token-generating backbone; only compiled in with the `llama` feature.
    #[cfg(feature = "llama")]
    pub backbone: BackboneModel,
    /// Decoder that turns speech token ids into `Vec<f32>` output.
    pub codec: NeuCodecDecoder,
    /// Language tag supplied at load time (`"en-us"` for `load_codec_only`).
    pub language: String,
    /// Generation settings used by `infer_from_ipa`.
    pub config: GenerationConfig,
}
51
52impl NeuTTS {
53 #[cfg(feature = "llama")]
54 pub fn load_with_decoder(
55 backbone_path: &Path,
56 decoder_path: &Path,
57 language: &str,
58 ) -> Result<Self> {
59 Self::load_with_decoder_on(backbone_path, decoder_path, language, Device::Cpu)
60 }
61
62 #[cfg(feature = "llama")]
63 pub fn load_with_decoder_on(
64 backbone_path: &Path,
65 decoder_path: &Path,
66 language: &str,
67 device: Device,
68 ) -> Result<Self> {
69 eprintln!("[neutts] Loading backbone: {}", backbone_path.display());
70 let backbone = BackboneModel::load_on(backbone_path, DEFAULT_N_CTX, device)
71 .context("Failed to load backbone")?;
72
73 eprintln!(
74 "[neutts] Loading NeuCodec decoder: {}",
75 decoder_path.display()
76 );
77 let codec = NeuCodecDecoder::from_file(decoder_path).with_context(|| {
78 format!(
79 "Failed to load NeuCodec decoder from {}",
80 decoder_path.display()
81 )
82 })?;
83
84 Ok(Self {
85 backbone,
86 codec,
87 language: language.to_string(),
88 config: GenerationConfig::default(),
89 })
90 }
91
92 #[cfg(feature = "llama")]
93 pub fn load(backbone_path: &Path, language: &str) -> Result<Self> {
94 Self::load_on(backbone_path, language, Device::Cpu)
95 }
96
97 #[cfg(feature = "llama")]
98 pub fn load_on(backbone_path: &Path, language: &str, device: Device) -> Result<Self> {
99 let decoder_path = crate::decoder::decoder_weights_path()?;
100 Self::load_with_decoder_on(backbone_path, &decoder_path, language, device)
101 }
102
103 #[cfg(not(feature = "llama"))]
104 pub fn load_codec_only() -> Result<Self> {
105 let codec = NeuCodecDecoder::new().context("Failed to initialise NeuCodec decoder")?;
106 Ok(Self {
107 codec,
108 language: "en-us".to_string(),
109 config: GenerationConfig::default(),
110 })
111 }
112
113 #[cfg(feature = "llama")]
114 pub fn infer_from_ipa(
115 &self,
116 input_ipa: &str,
117 ref_codes: &[i32],
118 ref_ipa: &str,
119 ) -> Result<Vec<f32>> {
120 let prompt = tokens::build_prompt(ref_ipa, input_ipa, ref_codes);
121 let generated = self
122 .backbone
123 .generate(&prompt, self.config.max_new_tokens)
124 .context("Backbone generation failed")?;
125
126 let speech_ids = tokens::extract_ids(&generated);
127 if speech_ids.is_empty() {
128 anyhow::bail!(
129 "No speech tokens in backbone output. Snippet: {:?}",
130 &generated[..generated.len().min(200)]
131 );
132 }
133
134 self.codec
135 .decode(&speech_ids)
136 .context("NeuCodec decode failed")
137 }
138
139 pub fn decode_tokens(&self, speech_ids: &[i32]) -> Result<Vec<f32>> {
140 self.codec
141 .decode(speech_ids)
142 .context("NeuCodec decode failed")
143 }
144}