1use crate::{Llama32Config, Llama32Generator, llama32_cfg_from_gguf};
17use anyhow::{Context, Result, anyhow, bail};
18use rlx_cli::{LmRunner, WeightFormat};
19use rlx_core::weight_loader::GgufLoader;
20use rlx_gguf::{GgufFile, MetaValue};
21use rlx_qwen3::SampleOpts;
22use rlx_runtime::{Device, Session};
23use std::path::{Path, PathBuf};
24
25pub type Llama32ConfigSource = rlx_runtime::ConfigSource<Llama32Config>;
37
38#[derive(Debug, Clone)]
39pub struct Llama32RunnerBuilder {
40 weights: Option<PathBuf>,
41 config: Option<Llama32ConfigSource>,
42 device: Option<Device>,
43 max_seq: Option<usize>,
44 max_memory_gb: Option<f32>,
45 stream: bool,
46 sample: Option<SampleOpts>,
47 format: Option<WeightFormat>,
48 packed_weights: Option<bool>,
50 bucketed_decode_cache: bool,
53}
54
55impl Default for Llama32RunnerBuilder {
56 fn default() -> Self {
57 Self {
58 weights: None,
59 config: None,
60 device: None,
61 max_seq: None,
62 max_memory_gb: None,
63 stream: true,
64 sample: None,
65 format: None,
66 packed_weights: None,
67 bucketed_decode_cache: true,
68 }
69 }
70}
71
72impl Llama32RunnerBuilder {
73 pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
74 self.weights = Some(path.into());
75 self
76 }
77
78 pub fn format(mut self, fmt: WeightFormat) -> Self {
79 self.format = Some(fmt);
80 self
81 }
82
83 pub fn config(mut self, src: Llama32ConfigSource) -> Self {
84 self.config = Some(src);
85 self
86 }
87
88 pub fn config_value(self, cfg: Llama32Config) -> Self {
89 self.config(Llama32ConfigSource::Explicit(cfg))
90 }
91
92 pub fn device(mut self, d: Device) -> Self {
93 self.device = Some(d);
94 self
95 }
96
97 pub fn max_seq(mut self, n: usize) -> Self {
98 self.max_seq = Some(n);
99 self
100 }
101
102 pub fn max_memory_gb(mut self, gb: f32) -> Self {
103 self.max_memory_gb = Some(gb);
104 self
105 }
106
107 pub fn stream(mut self, on: bool) -> Self {
108 self.stream = on;
109 self
110 }
111
112 pub fn sample(mut self, opts: SampleOpts) -> Self {
113 self.sample = Some(opts);
114 self
115 }
116
117 pub fn packed_weights(mut self, on: bool) -> Self {
123 self.packed_weights = Some(on);
124 self
125 }
126
127 pub fn bucketed_decode_cache(mut self, on: bool) -> Self {
129 self.bucketed_decode_cache = on;
130 self
131 }
132
133 pub fn build(self) -> Result<Llama32Runner> {
134 let weights_path = self
135 .weights
136 .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
137 let format = match self.format {
138 Some(f) => f,
139 None => WeightFormat::from_path(&weights_path)?,
140 };
141 let device = self.device.unwrap_or(Device::Cpu);
142 let max_seq = self.max_seq.unwrap_or(128);
143 let stream = self.stream;
144 let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
145
146 let (cfg, total_bytes_estimate) = match format {
147 WeightFormat::Gguf => load_llama32_gguf_config(&weights_path, self.config.as_ref())?,
148 WeightFormat::Safetensors => {
149 load_llama32_safetensors_config(&weights_path, self.config.as_ref())?
150 }
151 };
152
153 if let Some(cap_gb) = self.max_memory_gb {
154 let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
155 if est_gb > cap_gb {
156 bail!(
157 "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
158 );
159 }
160 }
161
162 let use_packed = self.packed_weights.unwrap_or_else(|| {
163 matches!(format, WeightFormat::Gguf)
164 && std::fs::metadata(&weights_path)
165 .map(|m| m.len() >= 256 * 1024 * 1024)
166 .unwrap_or(false)
167 });
168
169 crate::validate_device(&cfg, device, use_packed)?;
170
171 let path_str = weights_path
172 .to_str()
173 .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
174 let generator = if use_packed {
175 None
176 } else {
177 let mut loader = rlx_core::weight_loader::load_from_path(path_str)?;
178 let mut generator = Llama32Generator::from_loader_at(
179 cfg.clone(),
180 loader.as_mut(),
181 device,
182 &weights_path,
183 )?
184 .with_compile_seq_cap(max_seq)
185 .with_prefill_cache(8);
186 if self.bucketed_decode_cache {
187 generator = generator.with_decode_cache(max_seq.saturating_add(16).max(64));
188 }
189 Some(generator)
190 };
191
192 let packed = if use_packed {
193 if !matches!(format, WeightFormat::Gguf) {
194 bail!(
195 "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
196 format,
197 weights_path
198 );
199 }
200 eprintln!(
201 "[llama32-runner] packed_weights=true — compiling prefill graph with \
202 Op::DequantMatMul on {device:?}"
203 );
204 Some(Llama32PackedForward::build(
205 &cfg,
206 &weights_path,
207 max_seq,
208 device,
209 )?)
210 } else {
211 None
212 };
213
214 Ok(Llama32Runner {
215 generator,
216 cfg,
217 sample,
218 stream,
219 device,
220 packed,
221 })
222 }
223}
224
225struct Llama32PackedForward {
226 compiled: rlx_runtime::CompiledGraph,
227 seq: usize,
228 padded_ids: Vec<u32>,
229 ids_f32: Vec<f32>,
230 last_idx: [f32; 1],
231}
232
233impl Llama32PackedForward {
234 fn build(cfg: &Llama32Config, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
235 use crate::build_llama32_graph_sized_packed;
236 let exec_device = rlx_core::flow_bridge::packed_gguf_execution_device(device);
237 if exec_device != device {
238 eprintln!(
239 "[llama32-runner] packed GGUF on {device:?}: prefill executes on {exec_device:?} \
240 until {device:?} packed parity is fixed upstream"
241 );
242 }
243 let mut loader = GgufLoader::from_file(
244 weights_path
245 .to_str()
246 .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
247 )?;
248 let mut packed = std::collections::HashMap::new();
249 let (graph, params) =
250 build_llama32_graph_sized_packed(cfg, &mut loader, 1, seq, true, true, &mut packed)?;
251 let opts = rlx_core::flow_bridge::compile_options_for_packed_gguf_prefill(exec_device);
252 let mut compiled = rlx_core::flow_bridge::packed_gguf_compile_guard(exec_device, || {
253 Session::new(exec_device).compile_with(graph, &opts)
254 });
255 for (name, data) in ¶ms {
256 compiled.set_param(name, data);
257 }
258 for (name, (bytes, _scheme, _shape)) in &packed {
259 compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
260 }
261 Ok(Self {
262 compiled,
263 seq,
264 padded_ids: vec![0u32; seq],
265 ids_f32: vec![0f32; seq],
266 last_idx: [0f32; 1],
267 })
268 }
269}
270
271pub struct Llama32Runner {
272 generator: Option<Llama32Generator>,
273 cfg: Llama32Config,
274 sample: SampleOpts,
275 stream: bool,
276 device: Device,
277 packed: Option<Llama32PackedForward>,
278}
279
280impl Llama32Runner {
281 pub fn builder() -> Llama32RunnerBuilder {
282 Llama32RunnerBuilder::default()
283 }
284
285 pub fn config(&self) -> &Llama32Config {
286 &self.cfg
287 }
288
289 pub fn device(&self) -> Device {
290 self.device
291 }
292
293 pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
295 if let Some(p) = self.packed.as_mut() {
296 let n = prompt_ids.len().min(p.seq);
297 p.padded_ids.fill(0);
298 for (i, &t) in prompt_ids.iter().take(n).enumerate() {
299 p.padded_ids[i] = t;
300 }
301 for (dst, &id) in p.ids_f32.iter_mut().zip(p.padded_ids.iter()) {
302 *dst = id as f32;
303 }
304 p.last_idx[0] = n.saturating_sub(1) as f32;
305 let exec_device = p.compiled.device();
306 let out = rlx_core::run_packed_prefill(
307 &mut p.compiled,
308 exec_device,
309 n,
310 p.seq,
311 &[
312 ("input_ids", p.ids_f32.as_slice()),
313 ("last_token_idx", p.last_idx.as_slice()),
314 ],
315 );
316 let logits = out
317 .into_iter()
318 .next()
319 .ok_or_else(|| anyhow!("packed forward returned no output"))?;
320 let vocab = self.cfg.vocab_size;
321 if logits.len() < vocab {
322 bail!("logits short: {} < {vocab}", logits.len());
323 }
324 return Ok(logits[..vocab].to_vec());
325 }
326 let generator = self
327 .generator
328 .as_mut()
329 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
330 generator.prefill_get_last_logits(prompt_ids)
331 }
332
333 pub fn generate_packed(
334 &mut self,
335 prompt_ids: &[u32],
336 n_new: usize,
337 mut on_token: impl FnMut(u32),
338 ) -> Result<Vec<u32>> {
339 if self.packed.is_none() {
340 bail!("generate_packed() only works in packed_weights(true) mode");
341 }
342 let mut history: Vec<u32> = prompt_ids.to_vec();
343 let mut out = Vec::with_capacity(n_new);
344 for _ in 0..n_new {
345 let logits = self.predict_logits(&history)?;
346 let next = rlx_qwen3::sample_token(&logits, self.sample) as u32;
347 on_token(next);
348 history.push(next);
349 out.push(next);
350 }
351 Ok(out)
352 }
353
354 pub fn generate(
355 &mut self,
356 prompt_ids: &[u32],
357 n_new: usize,
358 mut on_token: impl FnMut(u32),
359 ) -> Result<Vec<u32>> {
360 if self.packed.is_some() {
361 return self.generate_packed(prompt_ids, n_new, on_token);
362 }
363 let generator = self
364 .generator
365 .as_mut()
366 .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
367 generator.prefill(prompt_ids);
368 let tokens = if self.stream {
369 generator.generate_cached_with(n_new, self.sample, &mut on_token)?
370 } else {
371 let toks = generator.generate_cached(n_new, self.sample)?;
372 for &t in &toks {
373 on_token(t);
374 }
375 toks
376 };
377 Ok(tokens)
378 }
379}
380
381impl LmRunner for Llama32Runner {
382 fn family(&self) -> &'static str {
383 "llama32"
384 }
385 fn vocab_size(&self) -> usize {
386 self.config().vocab_size
387 }
388 fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
389 Llama32Runner::predict_logits(self, prompt_ids)
390 }
391 fn generate(
392 &mut self,
393 prompt_ids: &[u32],
394 n_new: usize,
395 on_token: &mut dyn FnMut(u32) -> bool,
396 ) -> Result<Vec<u32>> {
397 Llama32Runner::generate(self, prompt_ids, n_new, |tok| {
398 let _ = on_token(tok);
399 })
400 }
401}
402
403fn load_llama32_gguf_config(
404 path: &Path,
405 override_src: Option<&Llama32ConfigSource>,
406) -> Result<(Llama32Config, u64)> {
407 let raw = GgufFile::from_path(path).with_context(|| format!("opening {path:?}"))?;
408 let arch = raw
409 .metadata
410 .get("general.architecture")
411 .and_then(MetaValue::as_str)
412 .unwrap_or("llama");
413 if arch != "llama" {
414 bail!(
415 "{path:?} has architecture {arch:?}; Llama32Runner expects general.architecture=llama"
416 );
417 }
418 let cfg = match override_src {
419 Some(Llama32ConfigSource::Explicit(c)) => c.clone(),
420 Some(Llama32ConfigSource::JsonFile(p)) => {
421 Llama32Config::from_file(p).with_context(|| format!("reading override config {p:?}"))?
422 }
423 Some(Llama32ConfigSource::Embedded) | None => llama32_cfg_from_gguf(&raw)?,
424 };
425 let bytes_est: u64 = raw
426 .tensors
427 .values()
428 .map(|t| (t.n_elements() as u64) * 4)
429 .sum();
430 Ok((cfg, bytes_est))
431}
432
433fn load_llama32_safetensors_config(
434 path: &Path,
435 override_src: Option<&Llama32ConfigSource>,
436) -> Result<(Llama32Config, u64)> {
437 let cfg_path = match override_src {
438 Some(Llama32ConfigSource::Explicit(c)) => {
439 return Ok((c.clone(), default_st_size_estimate(path)));
440 }
441 Some(Llama32ConfigSource::JsonFile(p)) => p.clone(),
442 Some(Llama32ConfigSource::Embedded) => {
443 bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
444 }
445 None => path
446 .parent()
447 .ok_or_else(|| anyhow!("weights path has no parent dir"))?
448 .join("config.json"),
449 };
450 let cfg = Llama32Config::from_file(&cfg_path)
451 .with_context(|| format!("reading config {cfg_path:?}"))?;
452 Ok((cfg, default_st_size_estimate(path)))
453}
454
/// Size of `path` on disk in bytes, or 0 when its metadata cannot be read.
///
/// Used as the byte estimate for safetensors checkpoints.
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}