Skip to main content

rlx_neutts/backbone/
rlx.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! NeuTTS GGUF backbone via [`rlx_llama32::Llama32Runner`] (rlx-models).
17//!
18//! Tokenisation uses the GGUF embedded vocab ([`rlx_qwen35::encode_prompt_from_gguf`]
19//! / [`rlx_qwen35::decode_ids_from_gguf`], same path as `rlx-llama32` CLI).
20//! Sampling matches the original NeuTTS defaults: top-k=50, top-p=0.9, temp=1.0.
21
22use std::path::{Path, PathBuf};
23use std::sync::Mutex;
24
25use anyhow::{Context, Result, bail};
26use rlx_core::validate_standard_device;
27use rlx_llama_base::LlamaBaseConfig;
28use rlx_llama32::{Llama32Runner, Llama32RunnerBuilder};
29use rlx_qwen3::{SampleOpts, sample_token};
30use rlx_qwen35::{decode_ids_from_gguf, encode_prompt_from_gguf};
31use rlx_runtime::Device;
32
33use crate::tokens::STOP_TOKEN;
34
/// Report whether environment variable `name` holds an affirmative flag value.
///
/// `1` and `true` (ASCII case-insensitive) count as affirmative. Whitespace
/// around the value is ignored, so `"1 "` — which Windows `cmd` stores for
/// `set VAR=1 && …` — or a value with a trailing newline is still recognised.
/// An unset or non-Unicode variable, and every other value, yields `false`.
fn env_truthy(name: &str) -> bool {
    std::env::var(name).is_ok_and(|raw| {
        let value = raw.trim();
        value == "1" || value.eq_ignore_ascii_case("true")
    })
}
40
/// Default context window size.
///
/// Must stay equal to `max_context = 2048` in the Python NeuTTS implementation.
/// Presumably meant as the `n_ctx` argument of the `BackboneModel::load*`
/// constructors; it is not referenced by the code visible in this file.
pub const DEFAULT_N_CTX: u32 = 2048;
43
/// NeuTTS backbone — RLX Llama-3.2 runner over a llama-tagged GGUF.
pub struct BackboneModel {
    /// The Llama-3.2 runner; locked by the `&self` generation methods for the
    /// whole duration of a generation run.
    runner: Mutex<Llama32Runner>,
    /// Path of the GGUF file; passed to the GGUF tokeniser helpers when
    /// encoding prompts and decoding generated token ids.
    weights: PathBuf,
    /// Context window handed to the runner as `max_seq`; prompts with more
    /// tokens than this are rejected before generation starts.
    n_ctx: u32,
    /// Optional fixed sampling seed; when `None`, a fresh random seed is drawn
    /// each time sampling options are built.
    pub seed: Option<u32>,
    /// When true, greedy parity uses incremental prefill+decode (llama.cpp-shaped).
    #[allow(dead_code)]
    greedy_parity: bool,
    /// `general.architecture` value parsed from the GGUF at load time; stored
    /// but not read by the code visible in this file.
    _arch: String,
}
55
56impl BackboneModel {
57    pub fn load(path: &Path, n_ctx: u32) -> Result<Self> {
58        Self::load_on(path, n_ctx, Device::Cpu)
59    }
60
61    /// Load the GGUF backbone on a specific execution device.
62    pub fn load_on(path: &Path, n_ctx: u32, device: Device) -> Result<Self> {
63        Self::load_inner(path, n_ctx, true, device, false)
64    }
65
66    /// F32 dequant + incremental greedy (tail parity vs llama-cpp Q4).
67    pub fn load_greedy_parity(path: &Path, n_ctx: u32) -> Result<Self> {
68        Self::load_greedy_parity_on(path, n_ctx, Device::Cpu)
69    }
70
71    /// Greedy parity load on a specific execution device.
72    pub fn load_greedy_parity_on(path: &Path, n_ctx: u32, device: Device) -> Result<Self> {
73        Self::load_inner(path, n_ctx, false, device, true)
74    }
75
76    fn load_inner(
77        path: &Path,
78        n_ctx: u32,
79        packed_weights: bool,
80        device: Device,
81        greedy_parity: bool,
82    ) -> Result<Self> {
83        validate_standard_device("neutts", device)?;
84        let base = LlamaBaseConfig::from_gguf_path(path)
85            .with_context(|| format!("parse GGUF {:?}", path))?;
86        // NeuTTS Nano/Air GGUFs are llama-tagged (same layout as LLaMA 3.2 / Bonsai).
87        if base.arch != "llama" {
88            bail!(
89                "rlx-neutts: expected `general.architecture = llama` in {}; got `{}`. \
90                 Point at a NeuTTS / Llama-shaped GGUF.",
91                path.display(),
92                base.arch
93            );
94        }
95
96        let runner = Llama32RunnerBuilder::default()
97            .weights(path)
98            .max_seq(n_ctx as usize)
99            .device(device)
100            .packed_weights(packed_weights)
101            .sample(SampleOpts::greedy())
102            .build()
103            .context("build Llama32Runner for NeuTTS backbone")?;
104
105        eprintln!(
106            "[backbone/rlx-llama32] loaded {} (hidden={}, layers={})",
107            path.display(),
108            base.hidden_size,
109            base.num_hidden_layers
110        );
111
112        Ok(Self {
113            runner: Mutex::new(runner),
114            weights: path.to_path_buf(),
115            n_ctx,
116            seed: None,
117            greedy_parity,
118            _arch: base.arch,
119        })
120    }
121
122    fn sample_opts(&self) -> SampleOpts {
123        let seed = self.seed.map(u64::from).unwrap_or_else(rand::random);
124        SampleOpts::temperature(1.0, seed)
125            .with_top_k(50)
126            .with_top_p(0.9)
127    }
128
129    pub fn generate(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
130        let mut output = String::new();
131        self.generate_streaming(prompt, max_new_tokens, |piece| {
132            output.push_str(piece);
133            Ok(())
134        })?;
135        Ok(output)
136    }
137
138    pub fn generate_streaming<F>(
139        &self,
140        prompt: &str,
141        max_new_tokens: u32,
142        mut on_piece: F,
143    ) -> Result<()>
144    where
145        F: FnMut(&str) -> Result<()>,
146    {
147        let prompt_ids = encode_prompt_from_gguf(&self.weights, prompt)
148            .with_context(|| format!("tokenize prompt for {}", self.weights.display()))?;
149
150        eprintln!(
151            "[backbone/rlx-llama32] prompt token count: {} / n_ctx={}",
152            prompt_ids.len(),
153            self.n_ctx
154        );
155        if prompt_ids.len() as u32 > self.n_ctx {
156            bail!(
157                "Prompt too long: {} tokens exceeds n_ctx={}",
158                prompt_ids.len(),
159                self.n_ctx
160            );
161        }
162        if prompt_ids.is_empty() {
163            return Ok(());
164        }
165
166        let mut ids = prompt_ids;
167        let sample = self.sample_opts();
168        let mut runner = self
169            .runner
170            .lock()
171            .map_err(|e| anyhow::anyhow!("backbone runner lock poisoned: {e}"))?;
172
173        for _ in 0..max_new_tokens {
174            let logits = runner
175                .predict_logits(&ids)
176                .context("RLX backbone predict_logits failed")?;
177            let next = sample_token(&logits, sample) as u32;
178
179            let piece = decode_ids_from_gguf(&self.weights, std::slice::from_ref(&next), true)
180                .with_context(|| format!("decode token {next}"))?;
181
182            if piece.is_empty() {
183                ids.push(next);
184                continue;
185            }
186
187            if let Some(pos) = piece.find(STOP_TOKEN) {
188                let before = &piece[..pos];
189                if !before.is_empty() {
190                    on_piece(before)?;
191                }
192                break;
193            }
194
195            on_piece(&piece)?;
196            ids.push(next);
197        }
198
199        Ok(())
200    }
201
202    /// Greedy token IDs for parity tests (same GGUF vocab as production).
203    pub fn generate_greedy_ids(&self, prompt: &str, max_new_tokens: u32) -> Result<Vec<u32>> {
204        let prompt_ids = encode_prompt_from_gguf(&self.weights, prompt)?;
205        self.generate_greedy_ids_from_prompt(&prompt_ids, max_new_tokens)
206    }
207
208    /// Greedy continuation for parity tests.
209    ///
210    /// [`load_greedy_parity`] uses KV-cached [`Llama32Runner::generate`] (F32 weights,
211    /// MSVC uses oneshot decode in `step_cached`). Production [`load`] uses packed Q4.
212    /// Debug: `NEUTTS_GREEDY_INCREMENTAL=1` or `NEUTTS_GREEDY_PREDICT_LOGITS=1`.
213    pub fn generate_greedy_ids_from_prompt(
214        &self,
215        prompt_ids: &[u32],
216        max_new_tokens: u32,
217    ) -> Result<Vec<u32>> {
218        let mut runner = self.runner.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
219        let n = max_new_tokens as usize;
220        if env_truthy("NEUTTS_GREEDY_PREDICT_LOGITS") {
221            let opts = SampleOpts::greedy();
222            let mut history = prompt_ids.to_vec();
223            let mut out = Vec::with_capacity(n);
224            for _ in 0..n {
225                let logits = runner
226                    .predict_logits(&history)
227                    .context("greedy parity predict_logits")?;
228                let next = sample_token(&logits, opts) as u32;
229                out.push(next);
230                history.push(next);
231            }
232            return Ok(out);
233        }
234        runner.generate(prompt_ids, n, |_| {})
235    }
236
237    /// Greedy text generation for parity tests (deterministic vs llama.cpp reference).
238    pub fn generate_greedy(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
239        let new_ids = self.generate_greedy_ids(prompt, max_new_tokens)?;
240        let mut out = String::new();
241        for &tok in &new_ids {
242            let piece = decode_ids_from_gguf(&self.weights, std::slice::from_ref(&tok), true)?;
243            if piece.find(STOP_TOKEN).is_some() {
244                break;
245            }
246            out.push_str(&piece);
247        }
248        Ok(out)
249    }
250}