Skip to main content

rlx_llama32/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3
4use crate::{Llama32Config, Llama32Generator, llama32_cfg_from_gguf};
5use anyhow::{Context, Result, anyhow, bail};
6use rlx_cli::{LmRunner, WeightFormat};
7use rlx_core::weight_loader::GgufLoader;
8use rlx_flow::CompileProfile;
9use rlx_gguf::{GgufFile, MetaValue};
10use rlx_qwen3::SampleOpts;
11use rlx_runtime::{Device, Session};
12use std::path::{Path, PathBuf};
13
14// ────────────────────────────────────────────────────────────────
15// LLaMA-3.2 runner — Meta Llama 3.x small LMs (1B / 3B).
16// ────────────────────────────────────────────────────────────────
17
18#[derive(Debug, Clone)]
19pub enum Llama32ConfigSource {
20    Embedded,
21    JsonFile(PathBuf),
22    Explicit(Llama32Config),
23}
24
25#[derive(Debug, Clone, Default)]
26pub struct Llama32RunnerBuilder {
27    weights: Option<PathBuf>,
28    config: Option<Llama32ConfigSource>,
29    device: Option<Device>,
30    max_seq: Option<usize>,
31    max_memory_gb: Option<f32>,
32    stream: bool,
33    sample: Option<SampleOpts>,
34    format: Option<WeightFormat>,
35    packed_weights: bool,
36}
37
38impl Llama32RunnerBuilder {
39    pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
40        self.weights = Some(path.into());
41        self
42    }
43
44    pub fn format(mut self, fmt: WeightFormat) -> Self {
45        self.format = Some(fmt);
46        self
47    }
48
49    pub fn config(mut self, src: Llama32ConfigSource) -> Self {
50        self.config = Some(src);
51        self
52    }
53
54    pub fn config_value(self, cfg: Llama32Config) -> Self {
55        self.config(Llama32ConfigSource::Explicit(cfg))
56    }
57
58    pub fn device(mut self, d: Device) -> Self {
59        self.device = Some(d);
60        self
61    }
62
63    pub fn max_seq(mut self, n: usize) -> Self {
64        self.max_seq = Some(n);
65        self
66    }
67
68    pub fn max_memory_gb(mut self, gb: f32) -> Self {
69        self.max_memory_gb = Some(gb);
70        self
71    }
72
73    pub fn stream(mut self, on: bool) -> Self {
74        self.stream = on;
75        self
76    }
77
78    pub fn sample(mut self, opts: SampleOpts) -> Self {
79        self.sample = Some(opts);
80        self
81    }
82
83    /// Keep K-quant weights packed in the arena (`Op::DequantMatMul`).
84    /// GGUF only. Supported on CPU, Metal, and MLX.
85    pub fn packed_weights(mut self, on: bool) -> Self {
86        self.packed_weights = on;
87        self
88    }
89
90    pub fn build(self) -> Result<Llama32Runner> {
91        let weights_path = self
92            .weights
93            .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?;
94        let format = match self.format {
95            Some(f) => f,
96            None => WeightFormat::from_path(&weights_path)?,
97        };
98        let device = self.device.unwrap_or(Device::Cpu);
99        let max_seq = self.max_seq.unwrap_or(128);
100        let stream = self.stream;
101        let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
102
103        let (cfg, total_bytes_estimate) = match format {
104            WeightFormat::Gguf => load_llama32_gguf_config(&weights_path, self.config.as_ref())?,
105            WeightFormat::Safetensors => {
106                load_llama32_safetensors_config(&weights_path, self.config.as_ref())?
107            }
108        };
109
110        if let Some(cap_gb) = self.max_memory_gb {
111            let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
112            if est_gb > cap_gb {
113                bail!(
114                    "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
115                );
116            }
117        }
118
119        crate::validate_device(&cfg, device, self.packed_weights)?;
120
121        let path_str = weights_path
122            .to_str()
123            .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
124        let generator = if self.packed_weights {
125            None
126        } else {
127            Some(
128                Llama32Generator::from_path(cfg.clone(), path_str, device)?
129                    .with_prefill_cache(2)
130                    .with_decode_cache(max_seq + 64),
131            )
132        };
133
134        let packed = if self.packed_weights {
135            if !matches!(format, WeightFormat::Gguf) {
136                bail!(
137                    "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
138                    format,
139                    weights_path
140                );
141            }
142            eprintln!(
143                "[llama32-runner] packed_weights=true — compiling prefill graph with \
144                 Op::DequantMatMul on {device:?}"
145            );
146            Some(Llama32PackedForward::build(
147                &cfg,
148                &weights_path,
149                max_seq,
150                device,
151            )?)
152        } else {
153            None
154        };
155
156        Ok(Llama32Runner {
157            generator,
158            cfg,
159            sample,
160            stream,
161            device,
162            packed,
163        })
164    }
165}
166
167struct Llama32PackedForward {
168    compiled: rlx_runtime::CompiledGraph,
169    seq: usize,
170}
171
172impl Llama32PackedForward {
173    fn build(cfg: &Llama32Config, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
174        use crate::build_llama32_graph_sized_packed;
175        let mut loader = GgufLoader::from_file(
176            weights_path
177                .to_str()
178                .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
179        )?;
180        let mut packed = std::collections::HashMap::new();
181        let (graph, params) =
182            build_llama32_graph_sized_packed(cfg, &mut loader, 1, seq, true, true, &mut packed)?;
183        let opts = rlx_core::flow_bridge::compile_options_for_profile(
184            &CompileProfile::llama32_prefill(),
185            device,
186        );
187        let mut compiled = Session::new(device).compile_with(graph, &opts);
188        for (name, data) in &params {
189            compiled.set_param(name, data);
190        }
191        for (name, (bytes, _scheme, _shape)) in &packed {
192            compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
193        }
194        Ok(Self { compiled, seq })
195    }
196}
197
198pub struct Llama32Runner {
199    generator: Option<Llama32Generator>,
200    cfg: Llama32Config,
201    sample: SampleOpts,
202    stream: bool,
203    device: Device,
204    packed: Option<Llama32PackedForward>,
205}
206
207impl Llama32Runner {
208    pub fn builder() -> Llama32RunnerBuilder {
209        Llama32RunnerBuilder::default()
210    }
211
212    pub fn config(&self) -> &Llama32Config {
213        &self.cfg
214    }
215
216    pub fn device(&self) -> Device {
217        self.device
218    }
219
220    /// Single prefill forward; returns last-position logits `[vocab]`.
221    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
222        if let Some(p) = self.packed.as_mut() {
223            let mut padded = vec![*prompt_ids.first().unwrap_or(&0); p.seq];
224            for (i, &t) in prompt_ids.iter().take(p.seq).enumerate() {
225                padded[i] = t;
226            }
227            let ids_f32: Vec<f32> = padded.iter().map(|&i| i as f32).collect();
228            let out = p.compiled.run(&[("input_ids", ids_f32.as_slice())]);
229            let logits = out
230                .into_iter()
231                .next()
232                .ok_or_else(|| anyhow!("packed forward returned no output"))?;
233            let vocab = self.cfg.vocab_size;
234            if logits.len() < vocab {
235                bail!("logits short: {} < {vocab}", logits.len());
236            }
237            return Ok(logits[..vocab].to_vec());
238        }
239        let generator = self
240            .generator
241            .as_mut()
242            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
243        generator.prefill_get_last_logits(prompt_ids)
244    }
245
246    pub fn generate_packed(
247        &mut self,
248        prompt_ids: &[u32],
249        n_new: usize,
250        mut on_token: impl FnMut(u32),
251    ) -> Result<Vec<u32>> {
252        if self.packed.is_none() {
253            bail!("generate_packed() only works in packed_weights(true) mode");
254        }
255        let mut history: Vec<u32> = prompt_ids.to_vec();
256        let mut out = Vec::with_capacity(n_new);
257        for _ in 0..n_new {
258            let logits = self.predict_logits(&history)?;
259            let next = rlx_qwen3::sample_token(&logits, self.sample) as u32;
260            on_token(next);
261            history.push(next);
262            out.push(next);
263        }
264        Ok(out)
265    }
266
267    pub fn generate(
268        &mut self,
269        prompt_ids: &[u32],
270        n_new: usize,
271        mut on_token: impl FnMut(u32),
272    ) -> Result<Vec<u32>> {
273        if self.packed.is_some() {
274            return self.generate_packed(prompt_ids, n_new, on_token);
275        }
276        let generator = self
277            .generator
278            .as_mut()
279            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
280        generator.prefill(prompt_ids);
281        let tokens = if self.stream {
282            generator.generate_cached_with(n_new, self.sample, &mut on_token)?
283        } else {
284            let toks = generator.generate_cached(n_new, self.sample)?;
285            for &t in &toks {
286                on_token(t);
287            }
288            toks
289        };
290        Ok(tokens)
291    }
292}
293
294impl LmRunner for Llama32Runner {
295    fn family(&self) -> &'static str {
296        "llama32"
297    }
298    fn vocab_size(&self) -> usize {
299        self.config().vocab_size
300    }
301    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
302        Llama32Runner::predict_logits(self, prompt_ids)
303    }
304    fn generate(
305        &mut self,
306        prompt_ids: &[u32],
307        n_new: usize,
308        on_token: &mut dyn FnMut(u32) -> bool,
309    ) -> Result<Vec<u32>> {
310        Llama32Runner::generate(self, prompt_ids, n_new, |tok| {
311            let _ = on_token(tok);
312        })
313    }
314}
315
316fn load_llama32_gguf_config(
317    path: &Path,
318    override_src: Option<&Llama32ConfigSource>,
319) -> Result<(Llama32Config, u64)> {
320    let raw = GgufFile::from_path(path).with_context(|| format!("opening {path:?}"))?;
321    let arch = raw
322        .metadata
323        .get("general.architecture")
324        .and_then(MetaValue::as_str)
325        .unwrap_or("llama");
326    if arch != "llama" {
327        bail!(
328            "{path:?} has architecture {arch:?}; Llama32Runner expects general.architecture=llama"
329        );
330    }
331    let cfg = match override_src {
332        Some(Llama32ConfigSource::Explicit(c)) => c.clone(),
333        Some(Llama32ConfigSource::JsonFile(p)) => {
334            Llama32Config::from_file(p).with_context(|| format!("reading override config {p:?}"))?
335        }
336        Some(Llama32ConfigSource::Embedded) | None => llama32_cfg_from_gguf(&raw)?,
337    };
338    let bytes_est: u64 = raw
339        .tensors
340        .values()
341        .map(|t| (t.n_elements() as u64) * 4)
342        .sum();
343    Ok((cfg, bytes_est))
344}
345
346fn load_llama32_safetensors_config(
347    path: &Path,
348    override_src: Option<&Llama32ConfigSource>,
349) -> Result<(Llama32Config, u64)> {
350    let cfg_path = match override_src {
351        Some(Llama32ConfigSource::Explicit(c)) => {
352            return Ok((c.clone(), default_st_size_estimate(path)));
353        }
354        Some(Llama32ConfigSource::JsonFile(p)) => p.clone(),
355        Some(Llama32ConfigSource::Embedded) => {
356            bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
357        }
358        None => path
359            .parent()
360            .ok_or_else(|| anyhow!("weights path has no parent dir"))?
361            .join("config.json"),
362    };
363    let cfg = Llama32Config::from_file(&cfg_path)
364        .with_context(|| format!("reading config {cfg_path:?}"))?;
365    Ok((cfg, default_st_size_estimate(path)))
366}
367
/// Best-effort size of a safetensors file: its length on disk in bytes,
/// or `0` when the file's metadata cannot be read.
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}