// rlx_llama32/runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::{Llama32Config, Llama32Generator, llama32_cfg_from_gguf};
17use anyhow::{Context, Result, anyhow, bail};
18use rlx_cli::{LmRunner, WeightFormat};
19use rlx_core::weight_loader::GgufLoader;
20use rlx_gguf::{GgufFile, MetaValue};
21use rlx_qwen3::SampleOpts;
22use rlx_runtime::{Device, Session};
23use std::path::{Path, PathBuf};
24
25// ────────────────────────────────────────────────────────────────
26// LLaMA-3.2 runner — Meta Llama 3.x small LMs (1B / 3B).
27// ────────────────────────────────────────────────────────────────
28
/// Where to load the Llama 3.2 config from.
///
/// Type alias of the shared `rlx_runtime::ConfigSource<T>` so the
/// per-family `*ConfigSource` enums no longer duplicate the same
/// `Embedded | JsonFile | Explicit(T)` shape. The variant constructors
/// (`Llama32ConfigSource::Embedded`, etc.) keep working because
/// type-alias resolution expands the path through the generic enum.
///
/// How the loaders in this file treat each variant:
/// - `Embedded`: derive the config from the GGUF file itself; rejected
///   with an error when the weights are safetensors.
/// - `JsonFile(path)`: load the config with `Llama32Config::from_file`.
/// - `Explicit(cfg)`: use a clone of the supplied config as-is.
pub type Llama32ConfigSource = rlx_runtime::ConfigSource<Llama32Config>;
37
/// Builder for [`Llama32Runner`].
///
/// Fields left as `None` are resolved to their defaults inside `build()`.
#[derive(Debug, Clone)]
pub struct Llama32RunnerBuilder {
    /// Path of the weights file. `build()` fails when this was never set.
    weights: Option<PathBuf>,
    /// Config override. `None` = derive from the GGUF file, or read
    /// `config.json` next to the weights for safetensors.
    config: Option<Llama32ConfigSource>,
    /// Target device. `None` = `Device::Cpu`.
    device: Option<Device>,
    /// Sequence length used to size the compiled graphs. `None` = 128.
    max_seq: Option<usize>,
    /// Optional cap, in GiB, on the estimated weight size; `build()` bails
    /// when the estimate exceeds it.
    max_memory_gb: Option<f32>,
    /// On the F32 path: `true` hands the token callback to the generator,
    /// `false` replays the tokens through the callback after generation.
    stream: bool,
    /// Sampling options. `None` = `SampleOpts::greedy()`.
    sample: Option<SampleOpts>,
    /// Weight file format. `None` = detected with `WeightFormat::from_path`.
    format: Option<WeightFormat>,
    /// `None` = auto (packed when GGUF ≥ 256 MB). `Some(_)` is an explicit override.
    packed_weights: Option<bool>,
    /// When false, decode uses one-shot graphs (slower compile, but
    /// avoids bucketed-cache edge cases on some GPU backends).
    bucketed_decode_cache: bool,
}
54
55impl Default for Llama32RunnerBuilder {
56    fn default() -> Self {
57        Self {
58            weights: None,
59            config: None,
60            device: None,
61            max_seq: None,
62            max_memory_gb: None,
63            stream: true,
64            sample: None,
65            format: None,
66            packed_weights: None,
67            bucketed_decode_cache: true,
68        }
69    }
70}
71
72impl Llama32RunnerBuilder {
73    pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
74        self.weights = Some(path.into());
75        self
76    }
77
78    pub fn format(mut self, fmt: WeightFormat) -> Self {
79        self.format = Some(fmt);
80        self
81    }
82
83    pub fn config(mut self, src: Llama32ConfigSource) -> Self {
84        self.config = Some(src);
85        self
86    }
87
88    pub fn config_value(self, cfg: Llama32Config) -> Self {
89        self.config(Llama32ConfigSource::Explicit(cfg))
90    }
91
92    pub fn device(mut self, d: Device) -> Self {
93        self.device = Some(d);
94        self
95    }
96
97    pub fn max_seq(mut self, n: usize) -> Self {
98        self.max_seq = Some(n);
99        self
100    }
101
102    pub fn max_memory_gb(mut self, gb: f32) -> Self {
103        self.max_memory_gb = Some(gb);
104        self
105    }
106
107    pub fn stream(mut self, on: bool) -> Self {
108        self.stream = on;
109        self
110    }
111
112    pub fn sample(mut self, opts: SampleOpts) -> Self {
113        self.sample = Some(opts);
114        self
115    }
116
117    /// Keep K-quant weights packed in the arena (`Op::DequantMatMul`).
118    /// GGUF only. Supported on CPU, Metal, and MLX.
119    ///
120    /// When unset, large GGUF files (≥ 256 MB on disk) auto-enable packed
121    /// prefill to avoid F32-dequant host memory blowups.
122    pub fn packed_weights(mut self, on: bool) -> Self {
123        self.packed_weights = Some(on);
124        self
125    }
126
127    /// Enable bucketed decode compile cache (default: true).
128    pub fn bucketed_decode_cache(mut self, on: bool) -> Self {
129        self.bucketed_decode_cache = on;
130        self
131    }
132
133    pub fn build(self) -> Result<Llama32Runner> {
134        let weights_path = self
135            .weights
136            .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
137        let format = match self.format {
138            Some(f) => f,
139            None => WeightFormat::from_path(&weights_path)?,
140        };
141        let device = self.device.unwrap_or(Device::Cpu);
142        let max_seq = self.max_seq.unwrap_or(128);
143        let stream = self.stream;
144        let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
145
146        let (cfg, total_bytes_estimate) = match format {
147            WeightFormat::Gguf => load_llama32_gguf_config(&weights_path, self.config.as_ref())?,
148            WeightFormat::Safetensors => {
149                load_llama32_safetensors_config(&weights_path, self.config.as_ref())?
150            }
151        };
152
153        if let Some(cap_gb) = self.max_memory_gb {
154            let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
155            if est_gb > cap_gb {
156                bail!(
157                    "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
158                );
159            }
160        }
161
162        let use_packed = self.packed_weights.unwrap_or_else(|| {
163            matches!(format, WeightFormat::Gguf)
164                && std::fs::metadata(&weights_path)
165                    .map(|m| m.len() >= 256 * 1024 * 1024)
166                    .unwrap_or(false)
167        });
168
169        crate::validate_device(&cfg, device, use_packed)?;
170
171        let path_str = weights_path
172            .to_str()
173            .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
174        let generator = if use_packed {
175            None
176        } else {
177            let mut loader = rlx_core::weight_loader::load_from_path(path_str)?;
178            let mut generator = Llama32Generator::from_loader_at(
179                cfg.clone(),
180                loader.as_mut(),
181                device,
182                &weights_path,
183            )?
184            .with_compile_seq_cap(max_seq)
185            .with_prefill_cache(8);
186            if self.bucketed_decode_cache {
187                generator = generator.with_decode_cache(max_seq.saturating_add(16).max(64));
188            }
189            Some(generator)
190        };
191
192        let packed = if use_packed {
193            if !matches!(format, WeightFormat::Gguf) {
194                bail!(
195                    "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
196                    format,
197                    weights_path
198                );
199            }
200            eprintln!(
201                "[llama32-runner] packed_weights=true — compiling prefill graph with \
202                 Op::DequantMatMul on {device:?}"
203            );
204            Some(Llama32PackedForward::build(
205                &cfg,
206                &weights_path,
207                max_seq,
208                device,
209            )?)
210        } else {
211            None
212        };
213
214        Ok(Llama32Runner {
215            generator,
216            cfg,
217            sample,
218            stream,
219            device,
220            packed,
221        })
222    }
223}
224
/// Compiled packed-weights prefill graph plus the scratch buffers that are
/// reused across `predict_logits` calls.
struct Llama32PackedForward {
    /// Compiled prefill graph with its parameters already bound.
    compiled: rlx_runtime::CompiledGraph,
    /// Sequence length the graph was built for; also the length of
    /// `padded_ids` and `ids_f32`.
    seq: usize,
    /// Prompt token ids, zero-padded to `seq`.
    padded_ids: Vec<u32>,
    /// `padded_ids` converted to `f32`, fed as the `input_ids` graph input.
    ids_f32: Vec<f32>,
    /// Index of the last real (non-padding) token, fed as `last_token_idx`.
    last_idx: [f32; 1],
}
232
233impl Llama32PackedForward {
234    fn build(cfg: &Llama32Config, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
235        use crate::build_llama32_graph_sized_packed;
236        let exec_device = rlx_core::flow_bridge::packed_gguf_execution_device(device);
237        if exec_device != device {
238            eprintln!(
239                "[llama32-runner] packed GGUF on {device:?}: prefill executes on {exec_device:?} \
240                 until {device:?} packed parity is fixed upstream"
241            );
242        }
243        let mut loader = GgufLoader::from_file(
244            weights_path
245                .to_str()
246                .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
247        )?;
248        let mut packed = std::collections::HashMap::new();
249        let (graph, params) =
250            build_llama32_graph_sized_packed(cfg, &mut loader, 1, seq, true, true, &mut packed)?;
251        let opts = rlx_core::flow_bridge::compile_options_for_packed_gguf_prefill(exec_device);
252        let mut compiled = rlx_core::flow_bridge::packed_gguf_compile_guard(exec_device, || {
253            Session::new(exec_device).compile_with(graph, &opts)
254        });
255        for (name, data) in &params {
256            compiled.set_param(name, data);
257        }
258        for (name, (bytes, _scheme, _shape)) in &packed {
259            compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
260        }
261        Ok(Self {
262            compiled,
263            seq,
264            padded_ids: vec![0u32; seq],
265            ids_f32: vec![0f32; seq],
266            last_idx: [0f32; 1],
267        })
268    }
269}
270
/// Llama 3.2 runner produced by [`Llama32RunnerBuilder::build`].
///
/// `build()` populates exactly one of `generator` (F32 path) and `packed`
/// (packed-weights path).
pub struct Llama32Runner {
    /// F32 generator; `None` in packed-weights mode.
    generator: Option<Llama32Generator>,
    /// Model config resolved at build time.
    cfg: Llama32Config,
    /// Sampling options used by the generation methods.
    sample: SampleOpts,
    /// F32 path only: whether tokens are streamed to the callback while
    /// generating, or replayed once generation has finished.
    stream: bool,
    /// Device requested at build time.
    device: Device,
    /// Packed prefill graph; `None` on the F32 path.
    packed: Option<Llama32PackedForward>,
}
279
280impl Llama32Runner {
281    pub fn builder() -> Llama32RunnerBuilder {
282        Llama32RunnerBuilder::default()
283    }
284
285    pub fn config(&self) -> &Llama32Config {
286        &self.cfg
287    }
288
289    pub fn device(&self) -> Device {
290        self.device
291    }
292
293    /// Single prefill forward; returns last-position logits `[vocab]`.
294    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
295        if let Some(p) = self.packed.as_mut() {
296            let n = prompt_ids.len().min(p.seq);
297            p.padded_ids.fill(0);
298            for (i, &t) in prompt_ids.iter().take(n).enumerate() {
299                p.padded_ids[i] = t;
300            }
301            for (dst, &id) in p.ids_f32.iter_mut().zip(p.padded_ids.iter()) {
302                *dst = id as f32;
303            }
304            p.last_idx[0] = n.saturating_sub(1) as f32;
305            let exec_device = p.compiled.device();
306            let out = rlx_core::run_packed_prefill(
307                &mut p.compiled,
308                exec_device,
309                n,
310                p.seq,
311                &[
312                    ("input_ids", p.ids_f32.as_slice()),
313                    ("last_token_idx", p.last_idx.as_slice()),
314                ],
315            );
316            let logits = out
317                .into_iter()
318                .next()
319                .ok_or_else(|| anyhow!("packed forward returned no output"))?;
320            let vocab = self.cfg.vocab_size;
321            if logits.len() < vocab {
322                bail!("logits short: {} < {vocab}", logits.len());
323            }
324            return Ok(logits[..vocab].to_vec());
325        }
326        let generator = self
327            .generator
328            .as_mut()
329            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
330        generator.prefill_get_last_logits(prompt_ids)
331    }
332
333    pub fn generate_packed(
334        &mut self,
335        prompt_ids: &[u32],
336        n_new: usize,
337        mut on_token: impl FnMut(u32),
338    ) -> Result<Vec<u32>> {
339        if self.packed.is_none() {
340            bail!("generate_packed() only works in packed_weights(true) mode");
341        }
342        let mut history: Vec<u32> = prompt_ids.to_vec();
343        let mut out = Vec::with_capacity(n_new);
344        for _ in 0..n_new {
345            let logits = self.predict_logits(&history)?;
346            let next = rlx_qwen3::sample_token(&logits, self.sample) as u32;
347            on_token(next);
348            history.push(next);
349            out.push(next);
350        }
351        Ok(out)
352    }
353
354    pub fn generate(
355        &mut self,
356        prompt_ids: &[u32],
357        n_new: usize,
358        mut on_token: impl FnMut(u32),
359    ) -> Result<Vec<u32>> {
360        if self.packed.is_some() {
361            return self.generate_packed(prompt_ids, n_new, on_token);
362        }
363        let generator = self
364            .generator
365            .as_mut()
366            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
367        generator.prefill(prompt_ids);
368        let tokens = if self.stream {
369            generator.generate_cached_with(n_new, self.sample, &mut on_token)?
370        } else {
371            let toks = generator.generate_cached(n_new, self.sample)?;
372            for &t in &toks {
373                on_token(t);
374            }
375            toks
376        };
377        Ok(tokens)
378    }
379}
380
381impl LmRunner for Llama32Runner {
382    fn family(&self) -> &'static str {
383        "llama32"
384    }
385    fn vocab_size(&self) -> usize {
386        self.config().vocab_size
387    }
388    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
389        Llama32Runner::predict_logits(self, prompt_ids)
390    }
391    fn generate(
392        &mut self,
393        prompt_ids: &[u32],
394        n_new: usize,
395        on_token: &mut dyn FnMut(u32) -> bool,
396    ) -> Result<Vec<u32>> {
397        Llama32Runner::generate(self, prompt_ids, n_new, |tok| {
398            let _ = on_token(tok);
399        })
400    }
401}
402
403fn load_llama32_gguf_config(
404    path: &Path,
405    override_src: Option<&Llama32ConfigSource>,
406) -> Result<(Llama32Config, u64)> {
407    let raw = GgufFile::from_path(path).with_context(|| format!("opening {path:?}"))?;
408    let arch = raw
409        .metadata
410        .get("general.architecture")
411        .and_then(MetaValue::as_str)
412        .unwrap_or("llama");
413    if arch != "llama" {
414        bail!(
415            "{path:?} has architecture {arch:?}; Llama32Runner expects general.architecture=llama"
416        );
417    }
418    let cfg = match override_src {
419        Some(Llama32ConfigSource::Explicit(c)) => c.clone(),
420        Some(Llama32ConfigSource::JsonFile(p)) => {
421            Llama32Config::from_file(p).with_context(|| format!("reading override config {p:?}"))?
422        }
423        Some(Llama32ConfigSource::Embedded) | None => llama32_cfg_from_gguf(&raw)?,
424    };
425    let bytes_est: u64 = raw
426        .tensors
427        .values()
428        .map(|t| (t.n_elements() as u64) * 4)
429        .sum();
430    Ok((cfg, bytes_est))
431}
432
433fn load_llama32_safetensors_config(
434    path: &Path,
435    override_src: Option<&Llama32ConfigSource>,
436) -> Result<(Llama32Config, u64)> {
437    let cfg_path = match override_src {
438        Some(Llama32ConfigSource::Explicit(c)) => {
439            return Ok((c.clone(), default_st_size_estimate(path)));
440        }
441        Some(Llama32ConfigSource::JsonFile(p)) => p.clone(),
442        Some(Llama32ConfigSource::Embedded) => {
443            bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
444        }
445        None => path
446            .parent()
447            .ok_or_else(|| anyhow!("weights path has no parent dir"))?
448            .join("config.json"),
449    };
450    let cfg = Llama32Config::from_file(&cfg_path)
451        .with_context(|| format!("reading config {cfg_path:?}"))?;
452    Ok((cfg, default_st_size_estimate(path)))
453}
454
/// Best-effort size estimate for a weights file: its length on disk in
/// bytes, or 0 when the file's metadata cannot be read.
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}