rlx-gemma 0.2.0

Gemma / Gemma 2 causal LMs for RLX
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

use crate::{GemmaConfig, GemmaGenerator, gemma_cfg_from_gguf};
use anyhow::{Context, Result, anyhow, bail};
use rlx_cli::{LmRunner, WeightFormat};
use rlx_core::gguf_support::{
    GgufModelFamily, ResolveWeightsOptions, assert_gguf_family, gguf_f32_bytes_estimate,
    resolve_weights_file_with_options,
};
use rlx_core::weight_loader::GgufLoader;
use rlx_flow::CompileProfile;
use rlx_qwen3::SampleOpts;
use rlx_runtime::{Device, Session};
use std::path::{Path, PathBuf};

// ────────────────────────────────────────────────────────────────
// Gemma runner — Gemma / Gemma 2 causal LMs.
// ────────────────────────────────────────────────────────────────

/// Where the runner builder obtains the [`GemmaConfig`] for a set of weights.
#[derive(Debug, Clone)]
pub enum GemmaConfigSource {
    /// Derive the config from the GGUF file itself (`gemma_cfg_from_gguf`).
    /// Rejected with an error when the weights are safetensors.
    Embedded,
    /// Load the config from the file at this path via `GemmaConfig::from_file`.
    JsonFile(PathBuf),
    /// Use this already-constructed config as-is (it is cloned into the runner).
    Explicit(GemmaConfig),
}

/// Builder for [`GemmaRunner`]. Every field starts unset (`Default`); the
/// fallbacks noted below are applied in `build()`.
#[derive(Debug, Clone, Default)]
pub struct GemmaRunnerBuilder {
    // Weights location handed to `resolve_weights_file_with_options`; required.
    weights: Option<PathBuf>,
    // Config source; `None` means "embedded" for GGUF and a sibling
    // `config.json` for safetensors.
    config: Option<GemmaConfigSource>,
    // Target device; falls back to `Device::Cpu`.
    device: Option<Device>,
    // Sequence budget; falls back to 128.
    max_seq: Option<usize>,
    // Optional cap (in GiB) checked against the weight-size estimate.
    max_memory_gb: Option<f32>,
    // Whether `GemmaRunner::generate` uses the per-token callback decode path.
    stream: bool,
    // Sampling options; falls back to `SampleOpts::greedy()`.
    sample: Option<SampleOpts>,
    // Explicit weight format passed to `WeightFormat::resolve`.
    format: Option<WeightFormat>,
    // When true, build the packed (`GemmaPackedForward`) path instead of the
    // `GemmaGenerator` path.
    packed_weights: bool,
}

impl GemmaRunnerBuilder {
    /// Sets the weights location. Required: `build()` fails without it.
    pub fn weights<P: Into<PathBuf>>(mut self, path: P) -> Self {
        self.weights = Some(path.into());
        self
    }

    /// Sets an explicit weight format, forwarded to `WeightFormat::resolve`
    /// together with the resolved weights path.
    pub fn format(mut self, fmt: WeightFormat) -> Self {
        self.format = Some(fmt);
        self
    }

    /// Chooses where the model config comes from.
    pub fn config(mut self, src: GemmaConfigSource) -> Self {
        self.config = Some(src);
        self
    }

    /// Shorthand for `config(GemmaConfigSource::Explicit(cfg))`.
    pub fn config_value(self, cfg: GemmaConfig) -> Self {
        self.config(GemmaConfigSource::Explicit(cfg))
    }

    /// Selects the execution device (defaults to `Device::Cpu`).
    pub fn device(mut self, d: Device) -> Self {
        self.device = Some(d);
        self
    }

    /// Sets the sequence budget (defaults to 128). It sizes the decode cache
    /// (`max_seq + 64`) on the generator path and is the sequence length the
    /// packed graph is built for.
    pub fn max_seq(mut self, n: usize) -> Self {
        self.max_seq = Some(n);
        self
    }

    /// Caps the estimated weight size, in GiB; `build()` fails when the
    /// estimate exceeds it.
    pub fn max_memory_gb(mut self, gb: f32) -> Self {
        self.max_memory_gb = Some(gb);
        self
    }

    /// Toggles the streaming decode path used by `GemmaRunner::generate`.
    pub fn stream(mut self, on: bool) -> Self {
        self.stream = on;
        self
    }

    /// Sets the sampling options (defaults to `SampleOpts::greedy()`).
    pub fn sample(mut self, opts: SampleOpts) -> Self {
        self.sample = Some(opts);
        self
    }

    /// Keep K-quant weights packed in the arena and run them through
    /// `Op::DequantMatMul` on the selected device. GGUF only: `build()`
    /// rejects any other weight format when this is enabled.
    pub fn packed_weights(mut self, on: bool) -> Self {
        self.packed_weights = on;
        self
    }

    /// Resolves the weights, loads the config, validates the request and
    /// constructs the runner.
    ///
    /// # Errors
    ///
    /// Fails when no weights path was given, the weights or config cannot be
    /// resolved/read, `packed_weights(true)` is combined with a non-GGUF
    /// file, the size estimate exceeds `max_memory_gb`, device validation
    /// fails, the weights path is not valid UTF-8, or the underlying
    /// generator / packed graph cannot be built.
    pub fn build(self) -> Result<GemmaRunner> {
        let resolve = ResolveWeightsOptions {
            prefer_gguf_substring: Some(rlx_core::DEFAULT_GGUF_PREFER_SUBSTR),
            ..Default::default()
        };
        let weights_path = resolve_weights_file_with_options(
            self.weights
                .as_ref()
                .ok_or_else(|| anyhow!("weights path required (call .weights(...))"))?,
            &resolve,
        )?;
        let format = WeightFormat::resolve(&weights_path, self.format)?;

        // Fail fast: a packed request on a non-GGUF file can never succeed,
        // so report it before touching config files or the device.
        if self.packed_weights && !matches!(format, WeightFormat::Gguf) {
            bail!(
                "packed_weights(true) requires a .gguf file; got {:?} for {:?}",
                format,
                weights_path
            );
        }

        let device = self.device.unwrap_or(Device::Cpu);
        let max_seq = self.max_seq.unwrap_or(128);
        let stream = self.stream;
        let sample = self.sample.unwrap_or_else(SampleOpts::greedy);

        let (cfg, total_bytes_estimate) = match format {
            WeightFormat::Gguf => load_gemma_gguf_config(&weights_path, self.config.as_ref())?,
            WeightFormat::Safetensors => {
                load_gemma_safetensors_config(&weights_path, self.config.as_ref())?
            }
        };

        // NOTE(review): the estimate is applied unchanged in packed mode even
        // though the message talks about an F32 dequant — confirm whether the
        // cap should use a packed-size estimate there.
        if let Some(cap_gb) = self.max_memory_gb {
            let est_gb = total_bytes_estimate as f32 / (1024.0 * 1024.0 * 1024.0);
            if est_gb > cap_gb {
                bail!(
                    "weights would dequant to ~{est_gb:.1} GB at F32, exceeds cap {cap_gb:.1} GB"
                );
            }
        }

        crate::capabilities::validate_device(&cfg, device, self.packed_weights)?;

        let path_str = weights_path
            .to_str()
            .ok_or_else(|| anyhow!("non-utf8 weights path"))?;

        // Exactly one backend is built: the packed prefill graph or the F32
        // generator with its caches.
        let (generator, packed) = if self.packed_weights {
            eprintln!(
                "[gemma-runner] packed_weights=true — compiling prefill graph with \
                 Op::DequantMatMul on {device:?}"
            );
            let forward = GemmaPackedForward::build(&cfg, &weights_path, max_seq, device)?;
            (None, Some(forward))
        } else {
            let generator = GemmaGenerator::from_path(cfg.clone(), path_str, device)?
                .with_prefill_cache(2)
                // Saturating so an extreme `max_seq` cannot overflow the add.
                .with_decode_cache(max_seq.saturating_add(64));
            (Some(generator), None)
        };

        Ok(GemmaRunner {
            generator,
            cfg,
            sample,
            stream,
            device,
            packed,
        })
    }
}

/// Compiled forward graph used when `packed_weights(true)` is requested.
struct GemmaPackedForward {
    // Graph compiled by `build`, with all parameters already set on it.
    compiled: rlx_runtime::CompiledGraph,
    // Sequence length the graph was built for; inputs are zero-padded to it.
    seq: usize,
}

impl GemmaPackedForward {
    fn build(cfg: &GemmaConfig, weights_path: &Path, seq: usize, device: Device) -> Result<Self> {
        use crate::build_gemma_graph_sized_packed;
        let mut loader = GgufLoader::from_file(
            weights_path
                .to_str()
                .ok_or_else(|| anyhow!("non-utf8 weights path"))?,
        )?;
        let mut packed = std::collections::HashMap::new();
        // `last_logits_only=false` so the runner can extract the row
        // at the real prompt's last index. Same fix as rlx-qwen3 /
        // rlx-llama32 — see `predict_logits` for the rationale.
        let (graph, params) =
            build_gemma_graph_sized_packed(cfg, &mut loader, 1, seq, true, false, &mut packed)?;
        let opts = rlx_core::flow_bridge::compile_options_for_profile(
            &CompileProfile::gemma_prefill(),
            device,
        );
        let mut compiled = Session::new(device).compile_with(graph, &opts);
        for (name, data) in &params {
            compiled.set_param(name, data);
        }
        for (name, (bytes, _scheme, _shape)) in &packed {
            compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
        }
        Ok(Self { compiled, seq })
    }
}

/// Gemma language-model runner produced by [`GemmaRunnerBuilder::build`].
///
/// The builder populates exactly one of `generator` / `packed`, depending on
/// whether `packed_weights(true)` was requested.
pub struct GemmaRunner {
    // F32 generator path; `None` in packed-weights mode.
    generator: Option<GemmaGenerator>,
    // Model config the runner was built with.
    cfg: GemmaConfig,
    // Sampling options applied to every generated token.
    sample: SampleOpts,
    // Selects the per-token callback decode path in `generate`.
    stream: bool,
    // Device the runner was built for.
    device: Device,
    // Packed-weights forward graph; `None` on the F32 generator path.
    packed: Option<GemmaPackedForward>,
}

impl GemmaRunner {
    /// Returns a builder with every option unset.
    pub fn builder() -> GemmaRunnerBuilder {
        GemmaRunnerBuilder::default()
    }

    /// The model config this runner was built with.
    pub fn config(&self) -> &GemmaConfig {
        &self.cfg
    }

    /// The device this runner was built for.
    pub fn device(&self) -> Device {
        self.device
    }

    /// Single prefill forward; returns last-position logits `[vocab]`.
    ///
    /// # Errors
    ///
    /// In packed-weights mode the graph is compiled for a fixed sequence
    /// length, so this fails when `prompt_ids` is empty or longer than that
    /// length (previously such prompts were silently truncated and the
    /// returned row did not correspond to the prompt's last token). It also
    /// fails when the forward pass yields no output or too few logits, and
    /// propagates any error from the generator on the non-packed path.
    pub fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
        if let Some(p) = self.packed.as_mut() {
            let n = prompt_ids.len();
            if n == 0 {
                bail!("prompt is empty; at least one token is required to predict logits");
            }
            if n > p.seq {
                bail!(
                    "prompt length {n} exceeds compiled max_seq {}; rebuild the runner with a \
                     larger .max_seq(...)",
                    p.seq
                );
            }

            // Zero-pad after the real prompt + extract logits at the
            // real last index. Same fix as rlx-qwen3 / rlx-llama32.
            let mut ids_f32 = vec![0.0f32; p.seq];
            for (slot, &tok) in ids_f32.iter_mut().zip(prompt_ids) {
                *slot = tok as f32;
            }

            let out = p.compiled.run(&[("input_ids", ids_f32.as_slice())]);
            let logits = out
                .into_iter()
                .next()
                .ok_or_else(|| anyhow!("packed forward returned no output"))?;
            let vocab = self.cfg.vocab_size;
            let expected = p.seq * vocab;
            if logits.len() < expected {
                bail!("logits short: {} < {expected}", logits.len());
            }
            // 1 <= n <= seq, so this row lies inside the `expected` prefix.
            let start = (n - 1) * vocab;
            return Ok(logits[start..start + vocab].to_vec());
        }
        let generator = self
            .generator
            .as_mut()
            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
        generator.prefill_get_last_logits(prompt_ids)
    }

    /// Generates `n_new` tokens in packed-weights mode by re-running the full
    /// prefill graph on the growing history, calling `on_token` for each new
    /// token as soon as it is sampled.
    ///
    /// # Errors
    ///
    /// Fails when the runner was not built with `packed_weights(true)`, or
    /// when the prompt plus the tokens that must be fed back would not fit in
    /// the compiled sequence length. The length check happens before any
    /// forward pass so no partial output is produced.
    pub fn generate_packed(
        &mut self,
        prompt_ids: &[u32],
        n_new: usize,
        mut on_token: impl FnMut(u32),
    ) -> Result<Vec<u32>> {
        let seq = match self.packed.as_ref() {
            Some(p) => p.seq,
            None => bail!("generate_packed() only works in packed_weights(true) mode"),
        };
        if n_new == 0 {
            return Ok(Vec::new());
        }

        // The final forward pass sees the prompt plus the first `n_new - 1`
        // generated tokens; the last sampled token is never fed back.
        let needed = prompt_ids.len().saturating_add(n_new - 1);
        if needed > seq {
            bail!(
                "prompt ({} tokens) + {n_new} new tokens needs a context of {needed} tokens, \
                 exceeds compiled max_seq {seq}; rebuild the runner with a larger .max_seq(...)",
                prompt_ids.len()
            );
        }

        let mut history: Vec<u32> = Vec::with_capacity(needed + 1);
        history.extend_from_slice(prompt_ids);
        let mut out = Vec::with_capacity(n_new);
        for _ in 0..n_new {
            let logits = self.predict_logits(&history)?;
            let next = rlx_qwen3::sample_token(&logits, self.sample) as u32;
            on_token(next);
            history.push(next);
            out.push(next);
        }
        Ok(out)
    }

    /// Generates `n_new` tokens after `prompt_ids` and returns them.
    ///
    /// In packed-weights mode this delegates to [`Self::generate_packed`].
    /// Otherwise the generator is prefilled with the prompt and decoded from
    /// its cache: with `stream` on, `on_token` is handed to the generator;
    /// with `stream` off, `on_token` is called for each token after the whole
    /// batch has been generated.
    ///
    /// # Errors
    ///
    /// Propagates errors from the packed path or from the generator.
    pub fn generate(
        &mut self,
        prompt_ids: &[u32],
        n_new: usize,
        mut on_token: impl FnMut(u32),
    ) -> Result<Vec<u32>> {
        if self.packed.is_some() {
            return self.generate_packed(prompt_ids, n_new, on_token);
        }
        let generator = self
            .generator
            .as_mut()
            .ok_or_else(|| anyhow!("F32 generator unavailable in packed_weights mode"))?;
        generator.prefill(prompt_ids);
        let tokens = if self.stream {
            generator.generate_cached_with(n_new, self.sample, &mut on_token)?
        } else {
            let toks = generator.generate_cached(n_new, self.sample)?;
            for &t in &toks {
                on_token(t);
            }
            toks
        };
        Ok(tokens)
    }
}

impl LmRunner for GemmaRunner {
    /// Family identifier for this runner.
    fn family(&self) -> &'static str {
        "gemma"
    }

    /// Vocabulary size taken from the runner's config.
    fn vocab_size(&self) -> usize {
        self.cfg.vocab_size
    }

    /// Forwards to the inherent [`GemmaRunner::predict_logits`].
    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>> {
        GemmaRunner::predict_logits(self, prompt_ids)
    }

    /// Forwards to the inherent [`GemmaRunner::generate`].
    fn generate(
        &mut self,
        prompt_ids: &[u32],
        n_new: usize,
        on_token: &mut dyn FnMut(u32) -> bool,
    ) -> Result<Vec<u32>> {
        // The inherent `generate` has no early-stop hook, so the callback's
        // boolean result is discarded and every token is produced.
        let forward = |tok: u32| {
            let _ = on_token(tok);
        };
        GemmaRunner::generate(self, prompt_ids, n_new, forward)
    }
}

/// Checks that the GGUF at `path` belongs to the Gemma family, then picks the
/// config according to `override_src` and pairs it with the F32 byte
/// estimate computed from the GGUF.
///
/// With no override (or `Embedded`) the config is derived from the GGUF
/// itself; `JsonFile` reads it from the given file; `Explicit` clones the
/// supplied value.
fn load_gemma_gguf_config(
    path: &Path,
    override_src: Option<&GemmaConfigSource>,
) -> Result<(GemmaConfig, u64)> {
    let gguf = assert_gguf_family(path, GgufModelFamily::Gemma)?;

    let cfg = match override_src {
        None | Some(GemmaConfigSource::Embedded) => gemma_cfg_from_gguf(&gguf)?,
        Some(GemmaConfigSource::JsonFile(json_path)) => GemmaConfig::from_file(json_path)
            .with_context(|| format!("reading override config {json_path:?}"))?,
        Some(GemmaConfigSource::Explicit(explicit)) => explicit.clone(),
    };

    let f32_bytes = gguf_f32_bytes_estimate(&gguf);
    Ok((cfg, f32_bytes))
}

/// Picks the config for safetensors weights at `path` and pairs it with the
/// on-disk size of the weights file as the size estimate.
///
/// `Explicit` is returned directly; `JsonFile` is read from the given path;
/// with no override the config is read from `config.json` next to the
/// weights. `Embedded` is an error because it only applies to GGUF.
fn load_gemma_safetensors_config(
    path: &Path,
    override_src: Option<&GemmaConfigSource>,
) -> Result<(GemmaConfig, u64)> {
    let cfg_path = match override_src {
        Some(GemmaConfigSource::Embedded) => {
            bail!("ConfigSource::Embedded only valid for GGUF; pass JsonFile for safetensors")
        }
        Some(GemmaConfigSource::Explicit(explicit)) => {
            let size = default_st_size_estimate(path);
            return Ok((explicit.clone(), size));
        }
        Some(GemmaConfigSource::JsonFile(json_path)) => json_path.clone(),
        // Default: look for `config.json` in the weights file's directory.
        None => match path.parent() {
            Some(dir) => dir.join("config.json"),
            None => bail!("weights path has no parent dir"),
        },
    };

    let cfg = GemmaConfig::from_file(&cfg_path)
        .with_context(|| format!("reading config {cfg_path:?}"))?;
    let size = default_st_size_estimate(path);
    Ok((cfg, size))
}

/// Size estimate for a safetensors file: its length on disk in bytes, or 0
/// when the file's metadata cannot be read.
fn default_st_size_estimate(path: &Path) -> u64 {
    match std::fs::metadata(path) {
        Ok(meta) => meta.len(),
        Err(_) => 0,
    }
}