rlx_neutts/backbone/
rlx.rs1use std::path::{Path, PathBuf};
23use std::sync::Mutex;
24
25use anyhow::{Context, Result, bail};
26use rlx_llama_base::LlamaBaseConfig;
27use rlx_llama32::{Llama32Runner, Llama32RunnerBuilder};
28use rlx_qwen3::{SampleOpts, sample_token};
29use rlx_qwen35::{decode_ids_from_gguf, encode_prompt_from_gguf};
30use rlx_runtime::Device;
31
32use crate::tokens::STOP_TOKEN;
33
/// Reports whether the environment variable `name` is switched on.
///
/// Only the exact value `1`, or `true` in any ASCII letter case, counts as on.
/// An unset variable, a value that is not valid Unicode, or any other value
/// counts as off.
fn env_truthy(name: &str) -> bool {
    let Ok(raw) = std::env::var(name) else {
        return false;
    };
    raw == "1" || raw.eq_ignore_ascii_case("true")
}
39
/// Suggested token budget for the `n_ctx` argument of `BackboneModel::load`.
///
/// It is presumably consumed by callers outside this module; nothing visible
/// here reads it.
pub const DEFAULT_N_CTX: u32 = 2_048;
42
43pub struct BackboneModel {
45 runner: Mutex<Llama32Runner>,
46 weights: PathBuf,
47 n_ctx: u32,
48 pub seed: Option<u32>,
49 #[allow(dead_code)]
51 greedy_parity: bool,
52 _arch: String,
53}
54
55impl BackboneModel {
56 pub fn load(path: &Path, n_ctx: u32) -> Result<Self> {
57 Self::load_inner(path, n_ctx, true, Device::Cpu, false)
58 }
59
60 pub fn load_greedy_parity(path: &Path, n_ctx: u32) -> Result<Self> {
62 Self::load_inner(path, n_ctx, false, Device::Cpu, true)
63 }
64
65 fn load_inner(
66 path: &Path,
67 n_ctx: u32,
68 packed_weights: bool,
69 device: Device,
70 greedy_parity: bool,
71 ) -> Result<Self> {
72 let base = LlamaBaseConfig::from_gguf_path(path)
73 .with_context(|| format!("parse GGUF {:?}", path))?;
74 if base.arch != "llama" {
76 bail!(
77 "rlx-neutts: expected `general.architecture = llama` in {}; got `{}`. \
78 Point at a NeuTTS / Llama-shaped GGUF.",
79 path.display(),
80 base.arch
81 );
82 }
83
84 let runner = Llama32RunnerBuilder::default()
85 .weights(path)
86 .max_seq(n_ctx as usize)
87 .device(device)
88 .packed_weights(packed_weights)
89 .sample(SampleOpts::greedy())
90 .build()
91 .context("build Llama32Runner for NeuTTS backbone")?;
92
93 eprintln!(
94 "[backbone/rlx-llama32] loaded {} (hidden={}, layers={})",
95 path.display(),
96 base.hidden_size,
97 base.num_hidden_layers
98 );
99
100 Ok(Self {
101 runner: Mutex::new(runner),
102 weights: path.to_path_buf(),
103 n_ctx,
104 seed: None,
105 greedy_parity,
106 _arch: base.arch,
107 })
108 }
109
110 fn sample_opts(&self) -> SampleOpts {
111 let seed = self.seed.map(u64::from).unwrap_or_else(rand::random);
112 SampleOpts::temperature(1.0, seed)
113 .with_top_k(50)
114 .with_top_p(0.9)
115 }
116
117 pub fn generate(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
118 let mut output = String::new();
119 self.generate_streaming(prompt, max_new_tokens, |piece| {
120 output.push_str(piece);
121 Ok(())
122 })?;
123 Ok(output)
124 }
125
126 pub fn generate_streaming<F>(
127 &self,
128 prompt: &str,
129 max_new_tokens: u32,
130 mut on_piece: F,
131 ) -> Result<()>
132 where
133 F: FnMut(&str) -> Result<()>,
134 {
135 let prompt_ids = encode_prompt_from_gguf(&self.weights, prompt)
136 .with_context(|| format!("tokenize prompt for {}", self.weights.display()))?;
137
138 eprintln!(
139 "[backbone/rlx-llama32] prompt token count: {} / n_ctx={}",
140 prompt_ids.len(),
141 self.n_ctx
142 );
143 if prompt_ids.len() as u32 > self.n_ctx {
144 bail!(
145 "Prompt too long: {} tokens exceeds n_ctx={}",
146 prompt_ids.len(),
147 self.n_ctx
148 );
149 }
150 if prompt_ids.is_empty() {
151 return Ok(());
152 }
153
154 let mut ids = prompt_ids;
155 let sample = self.sample_opts();
156 let mut runner = self
157 .runner
158 .lock()
159 .map_err(|e| anyhow::anyhow!("backbone runner lock poisoned: {e}"))?;
160
161 for _ in 0..max_new_tokens {
162 let logits = runner
163 .predict_logits(&ids)
164 .context("RLX backbone predict_logits failed")?;
165 let next = sample_token(&logits, sample) as u32;
166
167 let piece = decode_ids_from_gguf(&self.weights, std::slice::from_ref(&next), true)
168 .with_context(|| format!("decode token {next}"))?;
169
170 if piece.is_empty() {
171 ids.push(next);
172 continue;
173 }
174
175 if let Some(pos) = piece.find(STOP_TOKEN) {
176 let before = &piece[..pos];
177 if !before.is_empty() {
178 on_piece(before)?;
179 }
180 break;
181 }
182
183 on_piece(&piece)?;
184 ids.push(next);
185 }
186
187 Ok(())
188 }
189
190 pub fn generate_greedy_ids(&self, prompt: &str, max_new_tokens: u32) -> Result<Vec<u32>> {
192 let prompt_ids = encode_prompt_from_gguf(&self.weights, prompt)?;
193 self.generate_greedy_ids_from_prompt(&prompt_ids, max_new_tokens)
194 }
195
196 pub fn generate_greedy_ids_from_prompt(
202 &self,
203 prompt_ids: &[u32],
204 max_new_tokens: u32,
205 ) -> Result<Vec<u32>> {
206 let mut runner = self.runner.lock().map_err(|e| anyhow::anyhow!("{e}"))?;
207 let n = max_new_tokens as usize;
208 if env_truthy("NEUTTS_GREEDY_PREDICT_LOGITS") {
209 let opts = SampleOpts::greedy();
210 let mut history = prompt_ids.to_vec();
211 let mut out = Vec::with_capacity(n);
212 for _ in 0..n {
213 let logits = runner
214 .predict_logits(&history)
215 .context("greedy parity predict_logits")?;
216 let next = sample_token(&logits, opts) as u32;
217 out.push(next);
218 history.push(next);
219 }
220 return Ok(out);
221 }
222 runner.generate(prompt_ids, n, |_| {})
223 }
224
225 pub fn generate_greedy(&self, prompt: &str, max_new_tokens: u32) -> Result<String> {
227 let new_ids = self.generate_greedy_ids(prompt, max_new_tokens)?;
228 let mut out = String::new();
229 for &tok in &new_ids {
230 let piece = decode_ids_from_gguf(&self.weights, std::slice::from_ref(&tok), true)?;
231 if piece.find(STOP_TOKEN).is_some() {
232 break;
233 }
234 out.push_str(&piece);
235 }
236 Ok(out)
237 }
238}