rlx_neutts/backbone/
rlx.rs1use std::path::{Path, PathBuf};
23use std::sync::Mutex;
24
25use anyhow::{Context, Result, bail};
26use rlx_core::validate_standard_device;
27use rlx_llama_base::LlamaBaseConfig;
28use rlx_llama32::{Llama32Runner, Llama32RunnerBuilder};
29use rlx_qwen3::{SampleOpts, sample_token};
30use rlx_qwen35::{decode_ids_from_gguf, encode_prompt_from_gguf};
31use rlx_runtime::Device;
32
33use crate::tokens::STOP_TOKEN;
34
/// Reports whether the environment variable `name` holds a truthy value.
///
/// Only the exact string `1` and any ASCII-case variant of `true` count as
/// truthy. An unset variable, or one whose value is not valid Unicode, is
/// treated as `false`.
fn env_truthy(name: &str) -> bool {
    match std::env::var(name) {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    }
}
40
/// Default context length, in tokens.
///
/// NOTE(review): nothing in this file reads the constant; it is presumably
/// meant to be passed as the `n_ctx` argument of the `BackboneModel::load*`
/// constructors — confirm against callers.
pub const DEFAULT_N_CTX: u32 = 2048;
43
44pub struct BackboneModel {
46 runner: Mutex<Llama32Runner>,
47 weights: PathBuf,
48 n_ctx: u32,
49 pub seed: Option<u32>,
50 #[allow(dead_code)]
52 greedy_parity: bool,
53 _arch: String,
54}
55
56impl BackboneModel {
57 pub fn load(path: &Path, n_ctx: u32) -> Result<Self> {
58 Self::load_on(path, n_ctx, Device::Cpu)
59 }
60
61 pub fn load_on(path: &Path, n_ctx: u32, device: Device) -> Result<Self> {
63 Self::load_inner(path, n_ctx, true, device, false)
64 }
65
66 pub fn load_greedy_parity(path: &Path, n_ctx: u32) -> Result<Self> {
68 Self::load_greedy_parity_on(path, n_ctx, Device::Cpu)
69 }
70
71 pub fn load_greedy_parity_on(path: &Path, n_ctx: u32, device: Device) -> Result<Self> {
73 Self::load_inner(path, n_ctx, false, device, true)
74 }
75
76 fn load_inner(
77 path: &Path,
78 n_ctx: u32,
79 packed_weights: bool,
80 device: Device,
81 greedy_parity: bool,
82 ) -> Result<Self> {
83 validate_standard_device("neutts", device)?;
84 let base = LlamaBaseConfig::from_gguf_path(path)
85 .with_context(|| format!("parse GGUF {:?}", path))?;
86 if base.arch != "llama" {
88 bail!(
89 "rlx-neutts: expected `general.architecture = llama` in {}; got `{}`. \
90 Point at a NeuTTS / Llama-shaped GGUF.",
91 path.display(),
92 base.arch
93 );
94 }
95
96 let runner = Llama32RunnerBuilder::default()
97 .weights(path)
98 .max_seq(n_ctx as usize)
99 .device(device)
100 .packed_weights(packed_weights)
101 .sample(SampleOpts::greedy())
102 .build()
103 .context("build Llama32Runner for NeuTTS backbone")?;
104
105 eprintln!(
106 "[backbone/rlx-llama32] loaded {} (hidden={}, layers={})",
107 path.display(),
108 base.hidden_size,
109 base.num_hidden_layers
110 );
111
112 Ok(Self {
113 runner: Mutex::new(runner),
114 weights: path.to_path_buf(),
115 n_ctx,
116 seed: None,
117 greedy_parity,
118 _arch: base.arch,
119 })
120 }
121
122 fn sample_opts(&self) -> SampleOpts {
123 let seed = self.seed.map(u64::from).unwrap_or_else(rand::random);
124 SampleOpts::temperature(1.0, seed)
125 .with_top_k(50)
126 .with_top_p(0.9)
127 }
128
129 pub fn generate(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
130 let mut output = String::new();
131 self.generate_streaming(prompt, max_new_tokens, |piece| {
132 output.push_str(piece);
133 Ok(())
134 })?;
135 Ok(output)
136 }
137
138 pub fn generate_streaming<F>(
139 &self,
140 prompt: &str,
141 max_new_tokens: u32,
142 mut on_piece: F,
143 ) -> Result<()>
144 where
145 F: FnMut(&str) -> Result<()>,
146 {
147 let prompt_ids = encode_prompt_from_gguf(&self.weights, prompt)
148 .with_context(|| format!("tokenize prompt for {}", self.weights.display()))?;
149
150 eprintln!(
151 "[backbone/rlx-llama32] prompt token count: {} / n_ctx={}",
152 prompt_ids.len(),
153 self.n_ctx
154 );
155 if prompt_ids.len() as u32 > self.n_ctx {
156 bail!(
157 "Prompt too long: {} tokens exceeds n_ctx={}",
158 prompt_ids.len(),
159 self.n_ctx
160 );
161 }
162 if prompt_ids.is_empty() {
163 return Ok(());
164 }
165
166 let mut ids = prompt_ids;
167 let sample = self.sample_opts();
168 let mut runner = self
169 .runner
170 .lock()
171 .map_err(|e| anyhow::anyhow!("backbone runner lock poisoned: {e}"))?;
172
173 for _ in 0..max_new_tokens {
174 let logits = runner
175 .predict_logits(&ids)
176 .context("RLX backbone predict_logits failed")?;
177 let next = sample_token(&logits, sample) as u32;
178
179 let piece = decode_ids_from_gguf(&self.weights, std::slice::from_ref(&next), true)
180 .with_context(|| format!("decode token {next}"))?;
181
182 if piece.is_empty() {
183 ids.push(next);
184 continue;
185 }
186
187 if let Some(pos) = piece.find(STOP_TOKEN) {
188 let before = &piece[..pos];
189 if !before.is_empty() {
190 on_piece(before)?;
191 }
192 break;
193 }
194
195 on_piece(&piece)?;
196 ids.push(next);
197 }
198
199 Ok(())
200 }
201
202 pub fn generate_greedy_ids(&self, prompt: &str, max_new_tokens: u32) -> Result<Vec<u32>> {
204 let prompt_ids = encode_prompt_from_gguf(&self.weights, prompt)?;
205 self.generate_greedy_ids_from_prompt(&prompt_ids, max_new_tokens)
206 }
207
208 pub fn generate_greedy_ids_from_prompt(
214 &self,
215 prompt_ids: &[u32],
216 max_new_tokens: u32,
217 ) -> Result<Vec<u32>> {
218 let mut runner = self.runner.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
219 let n = max_new_tokens as usize;
220 if env_truthy("NEUTTS_GREEDY_PREDICT_LOGITS") {
221 let opts = SampleOpts::greedy();
222 let mut history = prompt_ids.to_vec();
223 let mut out = Vec::with_capacity(n);
224 for _ in 0..n {
225 let logits = runner
226 .predict_logits(&history)
227 .context("greedy parity predict_logits")?;
228 let next = sample_token(&logits, opts) as u32;
229 out.push(next);
230 history.push(next);
231 }
232 return Ok(out);
233 }
234 runner.generate(prompt_ids, n, |_| {})
235 }
236
237 pub fn generate_greedy(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
239 let new_ids = self.generate_greedy_ids(prompt, max_new_tokens)?;
240 let mut out = String::new();
241 for &tok in &new_ids {
242 let piece = decode_ids_from_gguf(&self.weights, std::slice::from_ref(&tok), true)?;
243 if piece.find(STOP_TOKEN).is_some() {
244 break;
245 }
246 out.push_str(&piece);
247 }
248 Ok(out)
249 }
250}