Skip to main content

rlx_gemma/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::{GemmaConfig, GemmaGenerator, gemma_cfg_from_gguf};
17use anyhow::{Context, Result, anyhow, bail};
18use rlx_cli::{LmRunner, WeightFormat};
19use rlx_core::gguf_support::{
20    GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
21    resolve_weights_file_with_options,
22};
23use rlx_core::weight_loader::GgufLoader;
24use rlx_flow::CompileProfile;
25use rlx_qwen3::SampleOpts;
26use rlx_runtime::{Device, Session};
27use std::path::{Path, PathBuf};
28
// ────────────────────────────────────────────────────────────────
// Gemma runner — Google Gemma language models.
// ────────────────────────────────────────────────────────────────
32
33#[derive(Debug, Clone)]
34pub enum GemmaConfigSource {
35    Embedded,
36    JsonFile(PathBuf),
37    Explicit(GemmaConfig),
38}
39
40#[derive(Debug, Clone, Default)]
41pub struct GemmaRunnerBuilder {
42    weights: Option<PathBuf>,
43    config: Option<GemmaConfigSource>,
44    device: Option<Device>,
45    max_seq: Option<usize>,
46    max_memory_gb: Option<f32>,
47    stream: bool,
48    sample: Option<SampleOpts>,
49    format: Option<WeightFormat>,
50    packed_weights: bool,
51}
52
53impl GemmaRunnerBuilder {
54    pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
55        self.weights = Some(path.into());
56        self
57    }
58
59    pub fn format(mut self, fmt: WeightFormat) -> Self {
60        self.format = Some(fmt);
61        self
62    }
63
64    pub fn config(mut self, src: GemmaConfigSource) -> Self {
65        self.config = Some(src);
66        self
67    }
68
69    pub fn config_value(self, cfg: GemmaConfig) -> Self {
70        self.config(GemmaConfigSource::Explicit(cfg))
71    }
72
73    pub fn device(mut self, d: Device) -> Self {
74        self.device = Some(d);
75        self
76    }
77
78    pub fn max_seq(mut self, n: usize) -> Self {
79        self.max_seq = Some(n);
80        self
81    }
82
83    pub fn max_memory_gb(mut self, gb: f32) -> Self {
84        self.max_memory_gb = Some(gb);
85        self
86    }
87
88    pub fn stream(mut self, on: bool) -> Self {
89        self.stream = on;
90        self
91    }
92
93    pub fn sample(mut self, opts: SampleOpts) -> Self {
94        self.sample = Some(opts);
95        self
96    }
97
98    /// Keep K-quant weights packed in the arena (`Op::DequantMatMul`).
99    /// GGUF only. Uses `Op::DequantMatMul` on the selected device.
100    pub fn packed_weights(mut self, on: bool) -> Self {
101        self.packed_weights = on;
102        self
103    }
104
105    pub fn build(self) -> Result<GemmaRunner> {
106        let resolve = ResolveWeightsOptions {
107            prefer_gguf_substring: Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR),
108            ..Default::default()
109        };
110        let weights_path = resolve_weights_file_with_options(
111            self.weights
112                .as_ref()
113                .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
114            &resolve,
115        )?;
116        let format = WeightFormat::resolve(&weights_path, self.format)?;
117        let device = self.device.unwrap_or(Device::Cpu);
118        let max_seq = self.max_seq.unwrap_or(128);
119        let stream = self.stream;
120        let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
121
122        let (cfg, total_bytes_estimate) = match format {
123            WeightFormat::Gguf => load_gemma_gguf_config(&weights_path, self.config.as_ref())?,
124            WeightFormat::Safetensors => {
125                load_gemma_safetensors_config(&weights_path, self.config.as_ref())?
126            }
127        };
128
129        if let Some(cap_gb) = self.max_memory_gb {
130            let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
131            if est_gb > cap_gb {
132                bail!(
133                    "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
134                );
135            }
136        }
137
138        crate::capabilities::validate_device(&cfg, device, self.packed_weights)?;
139
140        let path_str = weights_path
141            .to_str()
142            .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
143        let generator = if self.packed_weights {
144            None
145        } else {
146            Some(
147                GemmaGenerator::from_path(cfg.clone(), path_str, device)?
148                    .with_prefill_cache(2)
149                    .with_decode_cache(max_seq + 64),
150            )
151        };
152
153        let packed = if self.packed_weights {
154            if !matches!(format, WeightFormat::Gguf) {
155                bail!(
156                    "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
157                    format,
158                    weights_path
159                );
160            }
161            eprintln!(
162                "[gemma-runner] packed_weights=true — compiling prefill graph with \
163                 Op::DequantMatMul on {device:?}"
164            );
165            Some(GemmaPackedForward::build(
166                &cfg,
167                &weights_path,
168                max_seq,
169                device,
170            )?)
171        } else {
172            None
173        };
174
175        Ok(GemmaRunner {
176            generator,
177            cfg,
178            sample,
179            stream,
180            device,
181            packed,
182        })
183    }
184}
185
186struct GemmaPackedForward {
187    compiled: rlx_runtime::CompiledGraph,
188    seq: usize,
189}
190
191impl GemmaPackedForward {
192    fn build(cfg: &GemmaConfig, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
193        use crate::build_gemma_graph_sized_packed;
194        let mut loader = GgufLoader::from_file(
195            weights_path
196                .to_str()
197                .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
198        )?;
199        let mut packed = std::collections::HashMap::new();
200        // `last_logits_only=false` so the runner can extract the row
201        // at the real prompt's last index. Same fix as rlx-qwen3 /
202        // rlx-llama32 — see `predict_logits` for the rationale.
203        let (graph, params) =
204            build_gemma_graph_sized_packed(cfg, &mut loader, 1, seq, true, false, &mut packed)?;
205        let opts = rlx_core::flow_bridge::compile_options_for_profile(
206            &CompileProfile::gemma_prefill(),
207            device,
208        );
209        let mut compiled = Session::new(device).compile_with(graph, &opts);
210        for (name, data) in &params {
211            compiled.set_param(name, data);
212        }
213        for (name, (bytes, _scheme, _shape)) in &packed {
214            compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
215        }
216        Ok(Self { compiled, seq })
217    }
218}
219
220pub struct GemmaRunner {
221    generator: Option<GemmaGenerator>,
222    cfg: GemmaConfig,
223    sample: SampleOpts,
224    stream: bool,
225    device: Device,
226    packed: Option<GemmaPackedForward>,
227}
228
229impl GemmaRunner {
230    pub fn builder() -> GemmaRunnerBuilder {
231        GemmaRunnerBuilder::default()
232    }
233
234    pub fn config(&self) -> &GemmaConfig {
235        &self.cfg
236    }
237
238    pub fn device(&self) -> Device {
239        self.device
240    }
241
242    /// Single prefill forward; returns last-position logits `[vocab]`.
243    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
244        if let Some(p) = self.packed.as_mut() {
245            // Zero-pad after the real prompt + extract logits at the
246            // real last index. Same fix as rlx-qwen3 / rlx-llama32.
247            let n = prompt_ids.len().min(p.seq);
248            let last = n.saturating_sub(1);
249            let mut padded = vec![0u32; p.seq];
250            for (i, &t) in prompt_ids.iter().take(p.seq).enumerate() {
251                padded[i] = t;
252            }
253            let ids_f32: Vec<f32> = padded.iter().map(|&i| i as f32).collect();
254            let out = p.compiled.run(&[("input_ids", ids_f32.as_slice())]);
255            let logits = out
256                .into_iter()
257                .next()
258                .ok_or_else(|| anyhow!("packed forward returned no output"))?;
259            let vocab = self.cfg.vocab_size;
260            let expected = p.seq * vocab;
261            if logits.len() < expected {
262                bail!("logits short: {} < {expected}", logits.len());
263            }
264            let start = last * vocab;
265            return Ok(logits[start..start + vocab].to_vec());
266        }
267        let generator = self
268            .generator
269            .as_mut()
270            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
271        generator.prefill_get_last_logits(prompt_ids)
272    }
273
274    pub fn generate_packed(
275        &mut self,
276        prompt_ids: &[u32],
277        n_new: usize,
278        mut on_token: impl FnMut(u32),
279    ) -> Result<Vec<u32>> {
280        if self.packed.is_none() {
281            bail!("generate_packed() only works in packed_weights(true) mode");
282        }
283        let mut history: Vec<u32> = prompt_ids.to_vec();
284        let mut out = Vec::with_capacity(n_new);
285        for _ in 0..n_new {
286            let logits = self.predict_logits(&history)?;
287            let next = rlx_qwen3::sample_token(&logits, self.sample) as u32;
288            on_token(next);
289            history.push(next);
290            out.push(next);
291        }
292        Ok(out)
293    }
294
295    pub fn generate(
296        &mut self,
297        prompt_ids: &[u32],
298        n_new: usize,
299        mut on_token: impl FnMut(u32),
300    ) -> Result<Vec<u32>> {
301        if self.packed.is_some() {
302            return self.generate_packed(prompt_ids, n_new, on_token);
303        }
304        let generator = self
305            .generator
306            .as_mut()
307            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
308        generator.prefill(prompt_ids);
309        let tokens = if self.stream {
310            generator.generate_cached_with(n_new, self.sample, &mut on_token)?
311        } else {
312            let toks = generator.generate_cached(n_new, self.sample)?;
313            for &t in &toks {
314                on_token(t);
315            }
316            toks
317        };
318        Ok(tokens)
319    }
320}
321
322impl LmRunner for GemmaRunner {
323    fn family(&self) -> &'static str {
324        "gemma"
325    }
326    fn vocab_size(&self) -> usize {
327        self.config().vocab_size
328    }
329    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
330        GemmaRunner::predict_logits(self, prompt_ids)
331    }
332    fn generate(
333        &mut self,
334        prompt_ids: &[u32],
335        n_new: usize,
336        on_token: &mut dyn FnMut(u32) -> bool,
337    ) -> Result<Vec<u32>> {
338        // Inherent generate ignores stop signal — drop the bool.
339        GemmaRunner::generate(self, prompt_ids, n_new, |tok| {
340            let _ = on_token(tok);
341        })
342    }
343}
344
345fn load_gemma_gguf_config(
346    path: &Path,
347    override_src: Option<&GemmaConfigSource>,
348) -> Result<(GemmaConfig, u64)> {
349    let raw = assert_gguf_family(path, GgufModelFamily::Gemma)?;
350    let cfg = match override_src {
351        Some(GemmaConfigSource::Explicit(c)) => c.clone(),
352        Some(GemmaConfigSource::JsonFile(p)) => {
353            GemmaConfig::from_file(p).with_context(|| format!("reading override config {p:?}"))?
354        }
355        Some(GemmaConfigSource::Embedded) | None => gemma_cfg_from_gguf(&raw)?,
356    };
357    Ok((cfg, gguf_f32_bytes_estimate(&raw)))
358}
359
360fn load_gemma_safetensors_config(
361    path: &Path,
362    override_src: Option<&GemmaConfigSource>,
363) -> Result<(GemmaConfig, u64)> {
364    let cfg_path = match override_src {
365        Some(GemmaConfigSource::Explicit(c)) => {
366            return Ok((c.clone(), default_st_size_estimate(path)));
367        }
368        Some(GemmaConfigSource::JsonFile(p)) => p.clone(),
369        Some(GemmaConfigSource::Embedded) => {
370            bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
371        }
372        None => path
373            .parent()
374            .ok_or_else(|| anyhow!("weights path has no parent dir"))?
375            .join("config.json"),
376    };
377    let cfg = GemmaConfig::from_file(&cfg_path)
378        .with_context(|| format!("reading config {cfg_path:?}"))?;
379    Ok((cfg, default_st_size_estimate(path)))
380}
381
/// Size of the file at `path` in bytes, or 0 when its metadata cannot be read.
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        // Best effort: an unreadable path counts as zero bytes.
        Err(_) => 0,
    }
}