Skip to main content

rlx_neutts/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level NeuTTS runner — backbone + NeuCodec decoder.
17
18use std::path::Path;
19
20use anyhow::{Context, Result};
21
22use crate::decoder::NeuCodecDecoder;
23use crate::tokens;
24
25#[cfg(feature = "llama")]
26use crate::backbone::{BackboneModel, DEFAULT_N_CTX};
27
/// Tunable knobs for generation with the GGUF backbone.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Token budget handed to the backbone for a single generation call.
    pub max_new_tokens: u32,
}

impl GenerationConfig {
    /// Value of `max_new_tokens` produced by [`Default`].
    const DEFAULT_MAX_NEW_TOKENS: u32 = 2048;
}

impl Default for GenerationConfig {
    /// Returns a configuration with `max_new_tokens` set to 2048.
    fn default() -> Self {
        let max_new_tokens = Self::DEFAULT_MAX_NEW_TOKENS;
        Self { max_new_tokens }
    }
}
41
/// NeuTTS handle: GGUF backbone (optional) + NeuCodec decoder.
///
/// The `backbone` field only exists when the crate is built with the `llama`
/// feature; without it the handle carries the decoder alone.
pub struct NeuTTS {
    /// Backbone model that `infer_from_ipa` calls `generate` on.
    #[cfg(feature = "llama")]
    pub backbone: BackboneModel,
    /// Decoder that turns speech token ids into `f32` output.
    pub codec: NeuCodecDecoder,
    /// Language tag supplied at load time (`"en-us"` for `load_codec_only`).
    pub language: String,
    /// Generation settings; `max_new_tokens` is passed to the backbone.
    pub config: GenerationConfig,
}
50
51impl NeuTTS {
52    #[cfg(feature = "llama")]
53    pub fn load_with_decoder(
54        backbone_path: &Path,
55        decoder_path: &Path,
56        language: &str,
57    ) -> Result<Self> {
58        eprintln!("[neutts] Loading backbone: {}", backbone_path.display());
59        let backbone =
60            BackboneModel::load(backbone_path, DEFAULT_N_CTX).context("Failed to load backbone")?;
61
62        eprintln!(
63            "[neutts] Loading NeuCodec decoder: {}",
64            decoder_path.display()
65        );
66        let codec = NeuCodecDecoder::from_file(decoder_path).with_context(|| {
67            format!(
68                "Failed to load NeuCodec decoder from {}",
69                decoder_path.display()
70            )
71        })?;
72
73        Ok(Self {
74            backbone,
75            codec,
76            language: language.to_string(),
77            config: GenerationConfig::default(),
78        })
79    }
80
81    #[cfg(feature = "llama")]
82    pub fn load(backbone_path: &Path, language: &str) -> Result<Self> {
83        let decoder_path = crate::decoder::decoder_weights_path()?;
84        Self::load_with_decoder(backbone_path, &decoder_path, language)
85    }
86
87    #[cfg(not(feature = "llama"))]
88    pub fn load_codec_only() -> Result<Self> {
89        let codec = NeuCodecDecoder::new().context("Failed to initialise NeuCodec decoder")?;
90        Ok(Self {
91            codec,
92            language: "en-us".to_string(),
93            config: GenerationConfig::default(),
94        })
95    }
96
97    #[cfg(feature = "llama")]
98    pub fn infer_from_ipa(
99        &self,
100        input_ipa: &str,
101        ref_codes: &[i32],
102        ref_ipa: &str,
103    ) -> Result<Vec<f32>> {
104        let prompt = tokens::build_prompt(ref_ipa, input_ipa, ref_codes);
105        let generated = self
106            .backbone
107            .generate(&prompt, self.config.max_new_tokens)
108            .context("Backbone generation failed")?;
109
110        let speech_ids = tokens::extract_ids(&generated);
111        if speech_ids.is_empty() {
112            anyhow::bail!(
113                "No speech tokens in backbone output. Snippet: {:?}",
114                &generated[..generated.len().min(200)]
115            );
116        }
117
118        self.codec
119            .decode(&speech_ids)
120            .context("NeuCodec decode failed")
121    }
122
123    pub fn decode_tokens(&self, speech_ids: &[i32]) -> Result<Vec<f32>> {
124        self.codec
125            .decode(speech_ids)
126            .context("NeuCodec decode failed")
127    }
128}