Skip to main content

rlx_gemma/
runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use crate::{GemmaConfig, GemmaGenerator, gemma_cfg_from_gguf};
17use anyhow::{Context, Result, anyhow, bail};
18use rlx_cli::{LmRunner, WeightFormat};
19use rlx_core::gguf_support::{
20    GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
21    resolve_weights_file_with_options,
22};
23use rlx_qwen3::SampleOpts;
24use rlx_runtime::Device;
25use std::path::{Path, PathBuf};
26
27// ────────────────────────────────────────────────────────────────
// Gemma runner — Google Gemma family language models.
29// ────────────────────────────────────────────────────────────────
30
/// Where the runner obtains its [`GemmaConfig`].
///
/// This is an alias of the shared `rlx_runtime::ConfigSource<T>`, so it has
/// the same `Embedded | JsonFile | Explicit(T)` shape as the other family
/// `*ConfigSource` enums.
pub type GemmaConfigSource = rlx_runtime::ConfigSource<GemmaConfig>;
35
36#[derive(Debug, Clone, Default)]
37pub struct GemmaRunnerBuilder {
38    weights: Option<PathBuf>,
39    config: Option<GemmaConfigSource>,
40    device: Option<Device>,
41    max_seq: Option<usize>,
42    max_memory_gb: Option<f32>,
43    stream: bool,
44    sample: Option<SampleOpts>,
45    format: Option<WeightFormat>,
46    packed_weights: bool,
47}
48
49impl GemmaRunnerBuilder {
50    pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
51        self.weights = Some(path.into());
52        self
53    }
54
55    pub fn format(mut self, fmt: WeightFormat) -> Self {
56        self.format = Some(fmt);
57        self
58    }
59
60    pub fn config(mut self, src: GemmaConfigSource) -> Self {
61        self.config = Some(src);
62        self
63    }
64
65    pub fn config_value(self, cfg: GemmaConfig) -> Self {
66        self.config(GemmaConfigSource::Explicit(cfg))
67    }
68
69    pub fn device(mut self, d: Device) -> Self {
70        self.device = Some(d);
71        self
72    }
73
74    pub fn max_seq(mut self, n: usize) -> Self {
75        self.max_seq = Some(n);
76        self
77    }
78
79    pub fn max_memory_gb(mut self, gb: f32) -> Self {
80        self.max_memory_gb = Some(gb);
81        self
82    }
83
84    pub fn stream(mut self, on: bool) -> Self {
85        self.stream = on;
86        self
87    }
88
89    pub fn sample(mut self, opts: SampleOpts) -> Self {
90        self.sample = Some(opts);
91        self
92    }
93
94    /// Keep K-quant weights packed in the arena (`Op::DequantMatMul`).
95    /// GGUF only. Uses `Op::DequantMatMul` on the selected device.
96    pub fn packed_weights(mut self, on: bool) -> Self {
97        self.packed_weights = on;
98        self
99    }
100
101    pub fn build(self) -> Result<GemmaRunner> {
102        let resolve = ResolveWeightsOptions {
103            prefer_gguf_substring: Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR),
104            ..Default::default()
105        };
106        let weights_path = resolve_weights_file_with_options(
107            self.weights
108                .as_ref()
109                .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
110            &resolve,
111        )?;
112        let format = WeightFormat::resolve(&weights_path, self.format)?;
113        let device = self.device.unwrap_or(Device::Cpu);
114        let max_seq = self.max_seq.unwrap_or(128);
115        let stream = self.stream;
116        let sample = self.sample.unwrap_or_else(SampleOpts::greedy);
117
118        let (cfg, total_bytes_estimate) = match format {
119            WeightFormat::Gguf => load_gemma_gguf_config(&weights_path, self.config.as_ref())?,
120            WeightFormat::Safetensors => {
121                load_gemma_safetensors_config(&weights_path, self.config.as_ref())?
122            }
123        };
124
125        if let Some(cap_gb) = self.max_memory_gb {
126            let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
127            if est_gb > cap_gb {
128                bail!(
129                    "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
130                );
131            }
132        }
133
134        crate::capabilities::validate_device(&cfg, device, self.packed_weights)?;
135
136        let path_str = weights_path
137            .to_str()
138            .ok_or_else(|| anyhow!("non-utf8 weights path"))?;
139        let generator = if self.packed_weights {
140            None
141        } else {
142            Some(
143                GemmaGenerator::from_path(cfg.clone(), path_str, device)?
144                    .with_inference_caches(max_seq),
145            )
146        };
147
148        let packed = if self.packed_weights {
149            if !matches!(format, WeightFormat::Gguf) {
150                bail!(
151                    "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
152                    format,
153                    weights_path
154                );
155            }
156            eprintln!(
157                "[gemma-runner] packed_weights=true — Q4 prefill + bucketed decode on {device:?}"
158            );
159            Some(crate::packed_session::GemmaPackedSession::build(
160                cfg.clone(),
161                &weights_path,
162                max_seq,
163                device,
164            )?)
165        } else {
166            None
167        };
168
169        Ok(GemmaRunner {
170            generator,
171            cfg,
172            sample,
173            stream,
174            device,
175            packed,
176        })
177    }
178}
179
/// Gemma text-generation runner produced by [`GemmaRunnerBuilder`].
///
/// The builder fills exactly one of `generator` / `packed`, depending on
/// whether `packed_weights(true)` was requested.
pub struct GemmaRunner {
    /// F32 generator; `None` when the runner was built in packed mode.
    generator: Option<GemmaGenerator>,
    /// Model configuration the backend was built from.
    cfg: GemmaConfig,
    /// Sampling options applied to every generation call.
    sample: SampleOpts,
    /// Whether the F32 paths use the callback-driven generator entry points.
    stream: bool,
    /// Device the backend was built for.
    device: Device,
    /// Packed GGUF session; `Some` only in packed mode.
    packed: Option<crate::packed_session::GemmaPackedSession>,
}
188
189impl GemmaRunner {
190    pub fn builder() -> GemmaRunnerBuilder {
191        GemmaRunnerBuilder::default()
192    }
193
194    pub fn config(&self) -> &GemmaConfig {
195        &self.cfg
196    }
197
198    pub fn device(&self) -> Device {
199        self.device
200    }
201
202    /// Single prefill forward; returns last-position logits `[vocab]`.
203    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
204        if let Some(p) = self.packed.as_mut() {
205            return p.predict_logits(prompt_ids);
206        }
207        let generator = self
208            .generator
209            .as_mut()
210            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
211        generator.prefill_get_last_logits(prompt_ids)
212    }
213
214    pub fn generate_packed(
215        &mut self,
216        prompt_ids: &[u32],
217        n_new: usize,
218        on_token: impl FnMut(u32),
219    ) -> Result<Vec<u32>> {
220        if self.packed.is_none() {
221            bail!("generate_packed() only works in packed_weights(true) mode");
222        }
223        let sample = self.sample;
224        self.packed
225            .as_mut()
226            .unwrap()
227            .generate(prompt_ids, n_new, sample, on_token)
228    }
229
230    pub fn generate(
231        &mut self,
232        prompt_ids: &[u32],
233        n_new: usize,
234        mut on_token: impl FnMut(u32),
235    ) -> Result<Vec<u32>> {
236        if self.packed.is_some() {
237            return self.generate_packed(prompt_ids, n_new, on_token);
238        }
239        let generator = self
240            .generator
241            .as_mut()
242            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
243        generator.prefill(prompt_ids);
244        let tokens = if self.stream {
245            generator.generate_cached_with(n_new, self.sample, &mut on_token)?
246        } else {
247            let toks = generator.generate_cached(n_new, self.sample)?;
248            for &t in &toks {
249                on_token(t);
250            }
251            toks
252        };
253        Ok(tokens)
254    }
255
256    /// Generate after splicing vision/audio rows into pre-scaled text embeddings.
257    pub fn generate_from_embeds(
258        &mut self,
259        prompt_ids: &[u32],
260        inputs_embeds: &[f32],
261        n_new: usize,
262        mut on_token: impl FnMut(u32),
263    ) -> Result<Vec<u32>> {
264        if self.packed.is_some() {
265            bail!("generate_from_embeds is not supported with packed_weights(true)");
266        }
267        let generator = self
268            .generator
269            .as_mut()
270            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
271        let tokens = if self.stream {
272            generator.generate_from_embeds_with(
273                prompt_ids,
274                inputs_embeds,
275                n_new,
276                self.sample,
277                &mut on_token,
278            )?
279        } else {
280            let toks =
281                generator.generate_from_embeds(prompt_ids, inputs_embeds, n_new, self.sample)?;
282            for &t in &toks {
283                on_token(t);
284            }
285            toks
286        };
287        Ok(tokens)
288    }
289
290    /// Build fused inputs + run generation (text LM weights must include `model.embed_tokens.weight`).
291    pub fn generate_multimodal(
292        &mut self,
293        mm_cfg: &crate::multimodal::GemmaMultimodalConfig,
294        token_ids: &[u32],
295        image_embeds: &[f32],
296        audio_embeds: &[f32],
297        video_embeds: &[f32],
298        n_new: usize,
299        mut on_token: impl FnMut(u32),
300    ) -> Result<Vec<u32>> {
301        let generator = self
302            .generator
303            .as_ref()
304            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
305        let embeds = crate::multimodal_embed::build_multimodal_inputs_embeds(
306            generator.weights_cache(),
307            &self.cfg,
308            mm_cfg,
309            token_ids,
310            image_embeds,
311            audio_embeds,
312            video_embeds,
313        )?;
314        let attn_bias = crate::multimodal_mask::build_multimodal_prefill_attn_bias(
315            token_ids, &self.cfg, mm_cfg, 1,
316        );
317        let generator = self
318            .generator
319            .as_mut()
320            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
321        let tokens = if self.stream {
322            generator.generate_from_embeds_with_bias_and_callback(
323                token_ids,
324                &embeds,
325                attn_bias,
326                n_new,
327                self.sample,
328                &mut on_token,
329            )?
330        } else {
331            let toks = generator.generate_from_embeds_with_bias(
332                token_ids,
333                &embeds,
334                attn_bias,
335                n_new,
336                self.sample,
337            )?;
338            for &t in &toks {
339                on_token(t);
340            }
341            toks
342        };
343        Ok(tokens)
344    }
345}
346
347impl LmRunner for GemmaRunner {
348    fn family(&self) -> &'static str {
349        "gemma"
350    }
351    fn vocab_size(&self) -> usize {
352        self.config().vocab_size
353    }
354    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
355        GemmaRunner::predict_logits(self, prompt_ids)
356    }
357    fn generate(
358        &mut self,
359        prompt_ids: &[u32],
360        n_new: usize,
361        on_token: &mut dyn FnMut(u32) -> bool,
362    ) -> Result<Vec<u32>> {
363        // Inherent generate ignores stop signal — drop the bool.
364        GemmaRunner::generate(self, prompt_ids, n_new, |tok| {
365            let _ = on_token(tok);
366        })
367    }
368}
369
370fn load_gemma_gguf_config(
371    path: &Path,
372    override_src: Option<&GemmaConfigSource>,
373) -> Result<(GemmaConfig, u64)> {
374    let raw = assert_gguf_family(path, GgufModelFamily::Gemma)?;
375    let cfg = match override_src {
376        Some(GemmaConfigSource::Explicit(c)) => c.clone(),
377        Some(GemmaConfigSource::JsonFile(p)) => {
378            GemmaConfig::from_file(p).with_context(|| format!("reading override config {p:?}"))?
379        }
380        Some(GemmaConfigSource::Embedded) | None => gemma_cfg_from_gguf(&raw)?,
381    };
382    Ok((cfg, gguf_f32_bytes_estimate(&raw)))
383}
384
385fn load_gemma_safetensors_config(
386    path: &Path,
387    override_src: Option<&GemmaConfigSource>,
388) -> Result<(GemmaConfig, u64)> {
389    let cfg_path = match override_src {
390        Some(GemmaConfigSource::Explicit(c)) => {
391            return Ok((c.clone(), default_st_size_estimate(path)));
392        }
393        Some(GemmaConfigSource::JsonFile(p)) => p.clone(),
394        Some(GemmaConfigSource::Embedded) => {
395            bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
396        }
397        None => path
398            .parent()
399            .ok_or_else(|| anyhow!("weights path has no parent dir"))?
400            .join("config.json"),
401    };
402    let cfg = GemmaConfig::from_file(&cfg_path)
403        .with_context(|| format!("reading config {cfg_path:?}"))?;
404    Ok((cfg, default_st_size_estimate(path)))
405}
406
/// Best-effort size estimate for a safetensors file: its length on disk in
/// bytes, or 0 when the file's metadata cannot be read.
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}