Skip to main content

rlx_neutts/backbone/
rlx.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! NeuTTS GGUF backbone via [`rlx_llama32::Llama32Runner`] (rlx-models).
17//!
18//! Tokenisation uses the GGUF embedded vocab ([`rlx_qwen35::encode_prompt_from_gguf`]
19//! / [`rlx_qwen35::decode_ids_from_gguf`], same path as `rlx-llama32` CLI).
20//! Sampling matches the original NeuTTS defaults: top-k=50, top-p=0.9, temp=1.0.
21
22use std::path::{Path, PathBuf};
23use std::sync::Mutex;
24
25use anyhow::{Context, Result, bail};
26use rlx_llama_base::LlamaBaseConfig;
27use rlx_llama32::{Llama32Runner, Llama32RunnerBuilder};
28use rlx_qwen3::{SampleOpts, sample_token};
29use rlx_qwen35::{decode_ids_from_gguf, encode_prompt_from_gguf};
30use rlx_runtime::Device;
31
32use crate::tokens::STOP_TOKEN;
33
/// Reports whether the environment variable `name` is switched on.
///
/// A variable counts as "on" when its value is exactly `1` or spells `true`
/// in any ASCII letter case. Unset variables, values that are not valid
/// Unicode, and every other value count as "off".
fn env_truthy(name: &str) -> bool {
    match std::env::var(name) {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        // `var` fails for both "not set" and "not valid Unicode"; both mean off.
        Err(_) => false,
    }
}
39
/// Context window, in tokens, used when the caller has no other preference.
///
/// Kept equal to the Python implementation's `max_context = 2048`; change both
/// together.
pub const DEFAULT_N_CTX: u32 = 2048;
42
43/// NeuTTS backbone — RLX Llama-3.2 runner over a llama-tagged GGUF.
44pub struct BackboneModel {
45    runner: Mutex<Llama32Runner>,
46    weights: PathBuf,
47    n_ctx: u32,
48    pub seed: Option<u32>,
49    /// When true, greedy parity uses incremental prefill+decode (llama.cpp-shaped).
50    #[allow(dead_code)]
51    greedy_parity: bool,
52    _arch: String,
53}
54
55impl BackboneModel {
56    pub fn load(path: &Path, n_ctx: u32) -> Result<Self> {
57        Self::load_inner(path, n_ctx, true, Device::Cpu, false)
58    }
59
60    /// F32 dequant + incremental greedy (tail parity vs llama-cpp Q4).
61    pub fn load_greedy_parity(path: &Path, n_ctx: u32) -> Result<Self> {
62        Self::load_inner(path, n_ctx, false, Device::Cpu, true)
63    }
64
65    fn load_inner(
66        path: &Path,
67        n_ctx: u32,
68        packed_weights: bool,
69        device: Device,
70        greedy_parity: bool,
71    ) -> Result<Self> {
72        let base = LlamaBaseConfig::from_gguf_path(path)
73            .with_context(|| format!("parse GGUF {:?}", path))?;
74        // NeuTTS Nano/Air GGUFs are llama-tagged (same layout as LLaMA 3.2 / Bonsai).
75        if base.arch != "llama" {
76            bail!(
77                "rlx-neutts: expected `general.architecture = llama` in {}; got `{}`. \
78                 Point at a NeuTTS / Llama-shaped GGUF.",
79                path.display(),
80                base.arch
81            );
82        }
83
84        let runner = Llama32RunnerBuilder::default()
85            .weights(path)
86            .max_seq(n_ctx as usize)
87            .device(device)
88            .packed_weights(packed_weights)
89            .sample(SampleOpts::greedy())
90            .build()
91            .context("build Llama32Runner for NeuTTS backbone")?;
92
93        eprintln!(
94            "[backbone/rlx-llama32] loaded {} (hidden={}, layers={})",
95            path.display(),
96            base.hidden_size,
97            base.num_hidden_layers
98        );
99
100        Ok(Self {
101            runner: Mutex::new(runner),
102            weights: path.to_path_buf(),
103            n_ctx,
104            seed: None,
105            greedy_parity,
106            _arch: base.arch,
107        })
108    }
109
110    fn sample_opts(&self) -> SampleOpts {
111        let seed = self.seed.map(u64::from).unwrap_or_else(rand::random);
112        SampleOpts::temperature(1.0, seed)
113            .with_top_k(50)
114            .with_top_p(0.9)
115    }
116
117    pub fn generate(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
118        let mut output = String::new();
119        self.generate_streaming(prompt, max_new_tokens, |piece| {
120            output.push_str(piece);
121            Ok(())
122        })?;
123        Ok(output)
124    }
125
126    pub fn generate_streaming<F>(
127        &self,
128        prompt: &str,
129        max_new_tokens: u32,
130        mut on_piece: F,
131    ) -> Result<()>
132    where
133        F: FnMut(&str) -> Result<()>,
134    {
135        let prompt_ids = encode_prompt_from_gguf(&self.weights, prompt)
136            .with_context(|| format!("tokenize prompt for {}", self.weights.display()))?;
137
138        eprintln!(
139            "[backbone/rlx-llama32] prompt token count: {} / n_ctx={}",
140            prompt_ids.len(),
141            self.n_ctx
142        );
143        if prompt_ids.len() as u32 > self.n_ctx {
144            bail!(
145                "Prompt too long: {} tokens exceeds n_ctx={}",
146                prompt_ids.len(),
147                self.n_ctx
148            );
149        }
150        if prompt_ids.is_empty() {
151            return Ok(());
152        }
153
154        let mut ids = prompt_ids;
155        let sample = self.sample_opts();
156        let mut runner = self
157            .runner
158            .lock()
159            .map_err(|e| anyhow::anyhow!("backbone runner lock poisoned: {e}"))?;
160
161        for _ in 0..max_new_tokens {
162            let logits = runner
163                .predict_logits(&ids)
164                .context("RLX backbone predict_logits failed")?;
165            let next = sample_token(&logits, sample) as u32;
166
167            let piece = decode_ids_from_gguf(&self.weights, std::slice::from_ref(&next), true)
168                .with_context(|| format!("decode token {next}"))?;
169
170            if piece.is_empty() {
171                ids.push(next);
172                continue;
173            }
174
175            if let Some(pos) = piece.find(STOP_TOKEN) {
176                let before = &piece[..pos];
177                if !before.is_empty() {
178                    on_piece(before)?;
179                }
180                break;
181            }
182
183            on_piece(&piece)?;
184            ids.push(next);
185        }
186
187        Ok(())
188    }
189
190    /// Greedy token IDs for parity tests (same GGUF vocab as production).
191    pub fn generate_greedy_ids(&self, prompt: &str, max_new_tokens: u32) -> Result<Vec<u32>> {
192        let prompt_ids = encode_prompt_from_gguf(&self.weights, prompt)?;
193        self.generate_greedy_ids_from_prompt(&prompt_ids, max_new_tokens)
194    }
195
196    /// Greedy continuation for parity tests.
197    ///
198    /// [`load_greedy_parity`] uses KV-cached [`Llama32Runner::generate`] (F32 weights,
199    /// MSVC uses oneshot decode in `step_cached`). Production [`load`] uses packed Q4.
200    /// Debug: `NEUTTS_GREEDY_INCREMENTAL=1` or `NEUTTS_GREEDY_PREDICT_LOGITS=1`.
201    pub fn generate_greedy_ids_from_prompt(
202        &self,
203        prompt_ids: &[u32],
204        max_new_tokens: u32,
205    ) -> Result<Vec<u32>> {
206        let mut runner = self.runner.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
207        let n = max_new_tokens as usize;
208        if env_truthy("NEUTTS_GREEDY_PREDICT_LOGITS") {
209            let opts = SampleOpts::greedy();
210            let mut history = prompt_ids.to_vec();
211            let mut out = Vec::with_capacity(n);
212            for _ in 0..n {
213                let logits = runner
214                    .predict_logits(&history)
215                    .context("greedy parity predict_logits")?;
216                let next = sample_token(&logits, opts) as u32;
217                out.push(next);
218                history.push(next);
219            }
220            return Ok(out);
221        }
222        runner.generate(prompt_ids, n, |_| {})
223    }
224
225    /// Greedy text generation for parity tests (deterministic vs llama.cpp reference).
226    pub fn generate_greedy(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
227        let new_ids = self.generate_greedy_ids(prompt, max_new_tokens)?;
228        let mut out = String::new();
229        for &tok in &new_ids {
230            let piece = decode_ids_from_gguf(&self.weights, std::slice::from_ref(&tok), true)?;
231            if piece.find(STOP_TOKEN).is_some() {
232                break;
233            }
234            out.push_str(&piece);
235        }
236        Ok(out)
237    }
238}