rlx-gemma 0.2.0

Gemma / Gemma 2 causal LMs for RLX
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Gemma graph builders — thin wrappers over [`crate::flow::GemmaFlow`].

use crate::config::GemmaConfig;
use anyhow::{Result, anyhow};
use rlx_core::weight_loader::WeightLoader;
use rlx_ir::Graph;
use rlx_ir::hir::HirModule;
use std::collections::HashMap;

pub fn build_gemma_graph_sized(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    seq: usize,
    with_lm_head: bool,
    with_kv_outputs: bool,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let opts = crate::flow::GemmaPrefillOpts {
        batch,
        seq,
        dynamic_seq: false,
        with_lm_head,
        with_kv_outputs,
        last_logits_only: false,
        profile: None,
    };
    rlx_core::flow_util::graph_from_built(crate::flow::build_gemma_prefill_built(
        cfg, weights, &opts,
    )?)
}

pub fn build_gemma_graph_sized_last_logits(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    seq: usize,
    with_kv_outputs: bool,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let opts = crate::flow::GemmaPrefillOpts {
        batch,
        seq,
        dynamic_seq: false,
        with_lm_head: true,
        with_kv_outputs,
        last_logits_only: true,
        profile: None,
    };
    rlx_core::flow_util::graph_from_built(crate::flow::build_gemma_prefill_built(
        cfg, weights, &opts,
    )?)
}

/// Build a Gemma prefill [`HirModule`] with a dynamic sequence dimension.
///
/// `max_seq` is passed as the sequence size together with `dynamic_seq: true`.
/// NOTE(review): presumably it acts as the upper bound on sequence length — confirm.
/// The LM head and `last_logits_only` are always enabled, and profiling is off.
///
/// # Errors
///
/// Returns an error if `cfg` fails `validate_cfg`, if `batch` is not 1, or if
/// `crate::flow::build_gemma_prefill_flow` fails.
pub fn build_gemma_prefill_hir_dynamic_ext(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    max_seq: usize,
    with_kv_outputs: bool,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    validate_cfg(cfg)?;

    // The dynamic-sequence prefill path only supports a single sequence.
    let single_batch = batch == 1;
    if !single_batch {
        return Err(anyhow!("gemma: dynamic_seq prefill requires batch=1"));
    }

    crate::flow::build_gemma_prefill_flow(
        cfg,
        weights,
        &crate::flow::GemmaPrefillOpts {
            seq: max_seq,
            dynamic_seq: true,
            batch,
            with_kv_outputs,
            with_lm_head: true,
            last_logits_only: true,
            profile: None,
        },
    )
}

/// Build a fixed-shape Gemma decode [`Graph`] with `use_custom_mask` disabled.
///
/// Convenience shorthand for [`build_gemma_decode_graph_sized_ext`].
///
/// # Errors
///
/// Propagates any error returned by [`build_gemma_decode_graph_sized_ext`].
pub fn build_gemma_decode_graph_sized(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    past_seq: usize,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let use_custom_mask = false;
    build_gemma_decode_graph_sized_ext(cfg, weights, batch, past_seq, use_custom_mask)
}

pub fn build_gemma_decode_graph_sized_ext(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    past_seq: usize,
    use_custom_mask: bool,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    let opts = crate::flow::GemmaDecodeOpts {
        batch,
        past_seq,
        dynamic_past: false,
        use_custom_mask,
        profile: None,
    };
    crate::flow::build_gemma_decode_graph(cfg, weights, &opts)
}

/// Build a fixed-shape Gemma decode [`HirModule`] with `use_custom_mask` disabled.
///
/// Convenience shorthand for [`build_gemma_decode_hir_sized_ext`].
///
/// # Errors
///
/// Propagates any error returned by [`build_gemma_decode_hir_sized_ext`].
pub fn build_gemma_decode_hir_sized(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    past_seq: usize,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    let use_custom_mask = false;
    build_gemma_decode_hir_sized_ext(cfg, weights, batch, past_seq, use_custom_mask)
}

/// Build a fixed-shape Gemma decode [`HirModule`] via `crate::flow::build_gemma_decode_flow`.
///
/// The past length is static (`dynamic_past: false`) and profiling is off.
/// `use_custom_mask` is forwarded unchanged into `GemmaDecodeOpts`.
///
/// # Errors
///
/// Returns an error if `cfg` fails `validate_cfg` or if the decode flow build fails.
pub fn build_gemma_decode_hir_sized_ext(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    past_seq: usize,
    use_custom_mask: bool,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    validate_cfg(cfg)?;
    crate::flow::build_gemma_decode_flow(
        cfg,
        weights,
        &crate::flow::GemmaDecodeOpts {
            use_custom_mask,
            past_seq,
            batch,
            dynamic_past: false,
            profile: None,
        },
    )
}

/// Build a Gemma decode [`HirModule`] with a dynamic past-length dimension.
///
/// `max_past_seq` is passed as `past_seq` with `dynamic_past: true`.
/// NOTE(review): presumably it is the upper bound on the past length — confirm.
/// This entry point never enables `use_custom_mask`, and profiling is off.
///
/// # Errors
///
/// Returns an error if `cfg` fails `validate_cfg` or if the decode flow build fails.
pub fn build_gemma_decode_hir_dynamic_ext(
    cfg: &GemmaConfig,
    weights: &mut dyn WeightLoader,
    batch: usize,
    max_past_seq: usize,
) -> Result<(HirModule, HashMap<String, Vec<f32>>)> {
    validate_cfg(cfg)?;
    let dynamic_opts = crate::flow::GemmaDecodeOpts {
        past_seq: max_past_seq,
        dynamic_past: true,
        batch,
        use_custom_mask: false,
        profile: None,
    };
    crate::flow::build_gemma_decode_flow(cfg, weights, &dynamic_opts)
}

/// Packed K-quant prefill builder — currently an unimplemented stub.
///
/// The config is validated first, so configuration errors surface with their
/// own message. Otherwise the function always fails. Use unpacked weights or a
/// flow build instead.
///
/// # Errors
///
/// Always returns an error: either the one from `validate_cfg` or the
/// "not implemented" error.
// NOTE(review): with 7 parameters this sits at clippy's default
// `too_many_arguments` threshold, so the allow may be redundant. It is kept to
// avoid churn if the parameter list grows.
#[allow(clippy::too_many_arguments)]
pub fn build_gemma_graph_sized_packed(
    cfg: &GemmaConfig,
    _weights: &mut rlx_core::weight_loader::GgufLoader,
    _batch: usize,
    _seq: usize,
    _with_lm_head: bool,
    _last_logits_only: bool,
    _packed: &mut HashMap<String, (Vec<u8>, rlx_ir::quant::QuantScheme, Vec<usize>)>,
) -> Result<(Graph, HashMap<String, Vec<f32>>)> {
    validate_cfg(cfg).and_then(|()| {
        Err(anyhow!(
            "packed gemma prefill graphs are not implemented yet; use standard GGUF drain + GemmaFlow"
        ))
    })
}

/// Reject Gemma configurations that the builders in this module cannot lower.
///
/// # Errors
///
/// Returns an error when:
/// * `num_attention_heads` or `num_key_value_heads` is zero;
/// * `num_attention_heads` is not a multiple of `num_key_value_heads`;
/// * `attention_bias` is set (not yet supported).
fn validate_cfg(cfg: &GemmaConfig) -> Result<()> {
    // Zero head counts must be checked explicitly. `0.is_multiple_of(k)` is true
    // for every k, including k == 0, so the divisibility check below would
    // otherwise accept a config with no attention heads.
    if cfg.num_attention_heads == 0 || cfg.num_key_value_heads == 0 {
        return Err(anyhow!(
            "num_attention_heads ({}) and num_key_value_heads ({}) must both be non-zero",
            cfg.num_attention_heads,
            cfg.num_key_value_heads
        ));
    }
    if !cfg
        .num_attention_heads
        .is_multiple_of(cfg.num_key_value_heads)
    {
        return Err(anyhow!(
            "num_attention_heads ({}) must be divisible by num_key_value_heads ({})",
            cfg.num_attention_heads,
            cfg.num_key_value_heads
        ));
    }
    if cfg.attention_bias {
        return Err(anyhow!("attention_bias=true not yet wired for gemma"));
    }
    Ok(())
}