Skip to main content

rlx_neutts/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level NeuTTS runner — backbone + NeuCodec decoder.
17
18use std::path::Path;
19
20use anyhow::{Context, Result};
21use rlx_runtime::Device;
22
23use crate::decoder::NeuCodecDecoder;
24use crate::tokens;
25
26#[cfg(feature = "llama")]
27use crate::backbone::{BackboneModel, DEFAULT_N_CTX};
28
/// Tunable limits for a single run of the GGUF backbone.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Token budget forwarded to the backbone's `generate` call
    /// (presumably the cap on newly generated tokens — confirm against the
    /// backbone implementation).
    pub max_new_tokens: u32,
}

impl Default for GenerationConfig {
    /// Builds the stock configuration with a budget of 2048 new tokens.
    fn default() -> Self {
        let max_new_tokens = 2048;
        GenerationConfig { max_new_tokens }
    }
}
42
/// NeuTTS handle: GGUF backbone (optional) + NeuCodec decoder.
///
/// The `backbone` field only exists when the crate is compiled with the
/// `llama` feature; without that feature the handle carries the codec alone.
pub struct NeuTTS {
    /// Backbone model used to generate output from a prompt.
    #[cfg(feature = "llama")]
    pub backbone: BackboneModel,
    /// NeuCodec decoder; turns `i32` speech token ids into a `Vec<f32>`
    /// (presumably audio samples — confirm against the decoder module).
    pub codec: NeuCodecDecoder,
    /// Language tag recorded when the handle is built. It is stored as given
    /// and is not read by any method in this file.
    pub language: String,
    /// Generation settings handed to the backbone; the constructors in this
    /// file initialise it with `GenerationConfig::default()`.
    pub config: GenerationConfig,
}
51
52impl NeuTTS {
53    #[cfg(feature = "llama")]
54    pub fn load_with_decoder(
55        backbone_path: &Path,
56        decoder_path: &Path,
57        language: &str,
58    ) -> Result<Self> {
59        Self::load_with_decoder_on(backbone_path, decoder_path, language, Device::Cpu)
60    }
61
62    #[cfg(feature = "llama")]
63    pub fn load_with_decoder_on(
64        backbone_path: &Path,
65        decoder_path: &Path,
66        language: &str,
67        device: Device,
68    ) -> Result<Self> {
69        eprintln!("[neutts] Loading backbone: {}", backbone_path.display());
70        let backbone = BackboneModel::load_on(backbone_path, DEFAULT_N_CTX, device)
71            .context("Failed to load backbone")?;
72
73        eprintln!(
74            "[neutts] Loading NeuCodec decoder: {}",
75            decoder_path.display()
76        );
77        let codec = NeuCodecDecoder::from_file(decoder_path).with_context(|| {
78            format!(
79                "Failed to load NeuCodec decoder from {}",
80                decoder_path.display()
81            )
82        })?;
83
84        Ok(Self {
85            backbone,
86            codec,
87            language: language.to_string(),
88            config: GenerationConfig::default(),
89        })
90    }
91
92    #[cfg(feature = "llama")]
93    pub fn load(backbone_path: &Path, language: &str) -> Result<Self> {
94        Self::load_on(backbone_path, language, Device::Cpu)
95    }
96
97    #[cfg(feature = "llama")]
98    pub fn load_on(backbone_path: &Path, language: &str, device: Device) -> Result<Self> {
99        let decoder_path = crate::decoder::decoder_weights_path()?;
100        Self::load_with_decoder_on(backbone_path, &decoder_path, language, device)
101    }
102
103    #[cfg(not(feature = "llama"))]
104    pub fn load_codec_only() -> Result<Self> {
105        let codec = NeuCodecDecoder::new().context("Failed to initialise NeuCodec decoder")?;
106        Ok(Self {
107            codec,
108            language: "en-us".to_string(),
109            config: GenerationConfig::default(),
110        })
111    }
112
113    #[cfg(feature = "llama")]
114    pub fn infer_from_ipa(
115        &self,
116        input_ipa: &str,
117        ref_codes: &[i32],
118        ref_ipa: &str,
119    ) -> Result<Vec<f32>> {
120        let prompt = tokens::build_prompt(ref_ipa, input_ipa, ref_codes);
121        let generated = self
122            .backbone
123            .generate(&prompt, self.config.max_new_tokens)
124            .context("Backbone generation failed")?;
125
126        let speech_ids = tokens::extract_ids(&generated);
127        if speech_ids.is_empty() {
128            anyhow::bail!(
129                "No speech tokens in backbone output. Snippet: {:?}",
130                &generated[..generated.len().min(200)]
131            );
132        }
133
134        self.codec
135            .decode(&speech_ids)
136            .context("NeuCodec decode failed")
137    }
138
139    pub fn decode_tokens(&self, speech_ids: &[i32]) -> Result<Vec<f32>> {
140        self.codec
141            .decode(speech_ids)
142            .context("NeuCodec decode failed")
143    }
144}