rlx-gemma 0.2.4

Gemma / Gemma 2 causal LMs for RLX

rlx-gemma

Google Gemma, Gemma 2, Gemma 3, and Gemma 4 causal LMs on RLX. Single graph layer, every backend (CPU / Metal / MLX / CUDA / ROCm / wgpu / Vulkan), no Python.

Versions and features

Family                           Status                  Notes
Gemma 1 (2b, 7b)                                         Reference dense decoder.
Gemma 2 (9b, 27b)                                        Pre/post FFN RMS norms, attention soft-cap, logit soft-cap, alternating sliding-window mask.
Gemma 3 (1b to 27b)                                      Stride-6 sliding-window pattern (5 sliding + 1 full causal).
Gemma 4 (E2B, E4B, 12B unified)  ✅ LM, ⚠️ multimodal    See below.
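The layer patterns and the sliding-window mask in the table above fit in a few lines of plain Rust. This is a minimal sketch, not rlx-gemma's API: AttnKind, gemma2_layer_kind, stride6_layer_kind, and can_attend are hypothetical names. Two conventions are assumptions: Gemma 2 parity (even 0-based layers slide, as in the Hugging Face implementation) and the window test q - k < window.

```rust
// Hypothetical helper names; not rlx-gemma's API.

/// Kind of attention a decoder layer uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttnKind {
    Sliding,
    Full,
}

/// Gemma 2: sliding and full layers alternate.
/// Assumption: even 0-based indices are sliding (Hugging Face convention).
pub fn gemma2_layer_kind(layer: usize) -> AttnKind {
    if layer % 2 == 0 {
        AttnKind::Sliding
    } else {
        AttnKind::Full
    }
}

/// Gemma 3 / 4: five sliding layers, then one full causal layer,
/// so 0-based indices 5, 11, 17, ... use full attention.
pub fn stride6_layer_kind(layer: usize) -> AttnKind {
    if (layer + 1) % 6 == 0 {
        AttnKind::Full
    } else {
        AttnKind::Sliding
    }
}

/// Mask predicate: may query position `q` attend to key position `k`?
pub fn can_attend(kind: AttnKind, q: usize, k: usize, window: usize) -> bool {
    if k > q {
        return false; // causal: never attend to future tokens
    }
    match kind {
        AttnKind::Full => true,
        // Sliding: only the most recent `window` positions (assumed convention).
        AttnKind::Sliding => q - k < window,
    }
}
```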

Gemma 4 unified (google/gemma-4-12B)

The 12B "unified" model is encoder-free: image patches and audio frames are projected directly into the LM embedding space. This crate implements:

  • Per-layer attention shape. Sliding layers use head_dim=256, num_kv_heads=8. Full-attention layers use global_head_dim=512, num_global_key_value_heads=1. KV cache shapes diverge per layer in the decode flow.
  • Split RoPE. Sliding layers run plain RoPE at theta=10_000. Full-attention layers run proportional RoPE with partial_rotary_factor=0.25 at theta=1_000_000. Both prefill (static) and decode (static + dynamic) paths emit two RoPE tables and dispatch per layer.
  • attention_k_eq_v. The V projection is aliased to K (pre-RoPE), so the V matmul and the v_proj.weight load are skipped at runtime.
  • final_logit_softcapping=30.0 wired through the existing LogitSoftcap stage.
  • Unified token vocab (vocab_size=262_144) with image_token_id=258_880, audio_token_id=258_881, video_token_id=258_884, and the surrounding boi/eoi/boa/eoa markers.
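The per-layer dispatch described above reduces to a lookup on the layer index. A minimal sketch, assuming the stride-6 pattern from the family table and the 12B values listed above; LayerAttn, gemma4_layer_attn, and softcap are hypothetical names, not the crate's types. The soft-cap uses the usual cap * tanh(x / cap) form.

```rust
// Hypothetical helper names; not rlx-gemma's API.

/// Attention geometry for one Gemma 4 unified layer.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerAttn {
    pub head_dim: usize,
    pub num_kv_heads: usize,
    /// Number of leading head dims that receive RoPE.
    pub n_rot: usize,
    pub rope_theta: f64,
}

const PARTIAL_ROTARY_FACTOR: f64 = 0.25;

pub fn gemma4_layer_attn(layer: usize) -> LayerAttn {
    // Stride-6 pattern: 0-based indices 5, 11, 17, ... are full attention.
    if (layer + 1) % 6 == 0 {
        let head_dim = 512; // global_head_dim
        LayerAttn {
            head_dim,
            num_kv_heads: 1, // num_global_key_value_heads
            n_rot: (head_dim as f64 * PARTIAL_ROTARY_FACTOR) as usize, // 128
            rope_theta: 1_000_000.0,
        }
    } else {
        // Sliding layers: plain RoPE over the whole head.
        LayerAttn { head_dim: 256, num_kv_heads: 8, n_rot: 256, rope_theta: 10_000.0 }
    }
}

/// Soft-cap: keeps the result within [-cap, cap], roughly linear near zero.
/// cap = 30.0 for final_logit_softcapping.
pub fn softcap(x: f32, cap: f32) -> f32 {
    cap * (x / cap).tanh()
}
```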

Multimodal status:

  • ✅ Standalone vision + audio projector graphs (matmul + RMS norm + LM projection). Numerically verified on CPU, Metal, MLX, and wgpu.
  • ✅ Preprocessing: JPEG/PNG → patches (ImageNet mean/std by default), WAV → 16 kHz mono f32 (linear resampler), prompt template → media-placeholder token stream.
  • ✅ Safetensors loader for vision_tower.* / audio_tower.* weights (F32 / F16 / BF16).
  • ⚠️ Vision pool-to-soft-tokens uses a placeholder matmul. The reference cross-attention ("Q-Former-style") reduction from P patches to num_soft_tokens is not wired yet, so today the output shape is [B, P, output_proj_dims] rather than [B, num_soft_tokens, output_proj_dims]. This will be fixed in place once the reference projector weights are pinned; see build_vision_projection_hir.
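The two numeric preprocessing steps can be sketched with the standard library alone. This is an assumed shape, not the crate's loader: it takes already-decoded RGB bytes and interleaved f32 samples (no JPEG or WAV parsing), uses the common ImageNet mean/std constants, and the helper names are hypothetical. The resampler's output-length rounding and edge clamping may differ from rlx-gemma's.

```rust
// Hypothetical helper names; not rlx-gemma's API.

/// Common ImageNet normalization constants (per RGB channel).
const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalize interleaved RGB bytes: (p / 255 - mean) / std per channel.
pub fn normalize_rgb(rgb: &[u8]) -> Vec<f32> {
    rgb.chunks_exact(3)
        .flat_map(|px| {
            (0..3).map(move |c| (px[c] as f32 / 255.0 - IMAGENET_MEAN[c]) / IMAGENET_STD[c])
        })
        .collect()
}

/// Average interleaved channels into one mono channel (channels >= 1).
pub fn downmix_to_mono(interleaved: &[f32], channels: usize) -> Vec<f32> {
    interleaved
        .chunks_exact(channels)
        .map(|frame| frame.iter().sum::<f32>() / channels as f32)
        .collect()
}

/// Linear-interpolation resampler (rates must be non-zero).
pub fn resample_linear(input: &[f32], in_rate: u32, out_rate: u32) -> Vec<f32> {
    if input.is_empty() || in_rate == out_rate {
        return input.to_vec();
    }
    let out_len = (input.len() as u64 * out_rate as u64 / in_rate as u64) as usize;
    let step = in_rate as f64 / out_rate as f64; // input samples per output sample
    (0..out_len)
        .map(|i| {
            let pos = i as f64 * step;
            let i0 = pos.floor() as usize;
            let i1 = (i0 + 1).min(input.len() - 1); // clamp at the last sample
            let frac = (pos - i0 as f64) as f32;
            input[i0] * (1.0 - frac) + input[i1] * frac
        })
        .collect()
}
```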

Backend coverage

cargo build -p rlx-gemma --features <feature>. Feature → backend:

Feature          Device it enables
(default)        CPU
metal            Apple Metal
mlx              Apple MLX
cuda             NVIDIA CUDA
rocm             AMD ROCm
gpu              wgpu (Metal/Vulkan/DX12)
vulkan           Vulkan compute
apple-silicon    metal + mlx + gpu
all-backends     every device

Verified parity (M-series macOS, apple-silicon)

cargo test -p rlx-gemma --features apple-silicon
# 106 / 106 pass:
#   54 lib  (config, flow, multimodal, runner)
#    8 gemma4_backend_parity         (projector vs CPU on Metal/MLX/wgpu)
#   28 gemma4_lm_backend_parity      (prefill + Gemma 1/2/3 regression)
#    4 gemma4_decode_backend_parity  (full per-layer KV divergence)
#    4 gemma4_decode_throughput      (tok/s regression detection)
#    8 gemma4_reference_fixture      (3 auto-skip without RLX_GEMMA4_FIXTURE; 5 offline flow tests)

Numerical parity vs CPU on the Gemma-4 LM graph (k_eq_v + split RoPE + per-layer head_dim variation, tiny 6-layer config):

Backend               LM max|Δ|                                   Projector max|Δ|
MLX                   8.6e-8 (essentially bit-exact)              < 1e-3
wgpu                  1.2e-6                                      < 1e-3
Metal causal          1.1e-6 (precise scalar sgemm)               < 1e-3
Metal sliding         1.1e-6 (SDPA mask_kind=4 wired)             < 1e-3
CUDA / ROCm / Vulkan  compile-time gated, not run on this host    compile-time gated
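The max|Δ| column is the largest element-wise absolute difference between a backend's output and the CPU reference. A minimal sketch of such a check; max_abs_diff and check_parity are hypothetical helpers, and the repo's own parity tests may be structured differently. NaNs are propagated on purpose, because f32::max returns the non-NaN operand and would otherwise let a broken kernel pass.

```rust
// Hypothetical helper names; not the repo's test harness.

/// Largest absolute element-wise difference (the max|Δ| column).
/// Returns NaN if any difference is NaN, so a broken kernel cannot pass.
pub fn max_abs_diff(reference: &[f32], candidate: &[f32]) -> f32 {
    assert_eq!(reference.len(), candidate.len(), "shape mismatch");
    reference
        .iter()
        .zip(candidate)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f32, |acc, d| if acc.is_nan() || d.is_nan() { f32::NAN } else { acc.max(d) })
}

/// Compare a backend's output with the CPU reference under a tolerance.
pub fn check_parity(name: &str, reference: &[f32], candidate: &[f32], tol: f32) -> Result<f32, String> {
    let d = max_abs_diff(reference, candidate);
    if d.is_finite() && d <= tol {
        Ok(d)
    } else {
        Err(format!("{name}: max|Δ| = {d:e} exceeds tolerance {tol:e}"))
    }
}
```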

Metal precision: a knob, not a floor

Apple Silicon's simdgroup_float8x8 matmul units use reduced-precision internal accumulators (~fp16 class for the tensor-multiply step). That is fine for production inference, where the resulting ~1e-3 relative error is bounded by stable softmax and layer norms, but it is visible (~1e-1 absolute) on tiny parity-test logits. Set RLX_METAL_SGEMM_VARIANT=naive (or RLX_METAL_PRECISE=1) to route matmuls through the scalar fp32 sgemm kernel and recover full CPU parity. The parity tests set this automatically.
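The effect of a reduced-precision accumulator can be reproduced on the CPU by rounding a running sum to an fp16-like 11-bit significand after every step. This is a simulation under that assumption only: it ignores fp16's exponent range, it does not model the actual internal width or blocking of Apple's simdgroup units, and a strictly serial sum exaggerates the drift compared with tiled GPU kernels. The function names are hypothetical.

```rust
// Hypothetical helper names; a CPU simulation, not the Metal kernel.

/// Round an f32 to an 11-bit significand (10 stored mantissa bits, as in
/// IEEE fp16), ignoring fp16's smaller exponent range. Ties round away from zero.
pub fn round_to_half_precision(x: f32) -> f32 {
    const DROP: u32 = 23 - 10; // f32 mantissa bits minus fp16 mantissa bits
    let bits = x.to_bits().wrapping_add(1 << (DROP - 1));
    f32::from_bits(bits & !((1u32 << DROP) - 1))
}

/// Dot product with a full f32 accumulator.
pub fn dot_f32(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Same dot product, but the running sum is rounded after every step,
/// simulating a reduced-precision accumulator.
pub fn dot_reduced_acc(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .fold(0.0_f32, |acc, (x, y)| round_to_half_precision(acc + x * y))
}
```

Summing 1000 copies of 0.1 this way ends several units away from 100, while the f32 accumulator stays within about 1e-3, which is the same kind of gap the precise sgemm path closes.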

This work also tightened rlx-metal's SDPA kernels to use precise::sqrt + precise::exp (avoiding the relaxed-precision metal:: namespace defaults; see kernels.rs::sdpa{,_long,_h}) and wired MaskKind::SlidingWindow(w) end-to-end so Gemma 3 / 4 alternating attention now runs on Metal where it previously panicked.

  • CUDA / ROCm / Vulkan have compile-time parity (is_available()-gated tests auto-skip without a live driver) and will execute when a host with the matching backend runs cargo test --features cuda (or rocm, vulkan).

Examples

Load a config and run the text LM

use rlx_gemma::{GemmaConfig, GemmaRunner, GemmaConfigSource};
use rlx_qwen3::SampleOpts;
use rlx_runtime::Device;

let mut runner = GemmaRunner::builder()
    .weights("/path/to/gemma-4-12b/model.safetensors")
    .device(Device::Metal)
    .max_seq(512)
    .sample(SampleOpts::greedy())
    .build()?;

let prompt_ids = rlx_gemma::encode_prompt_auto(
    runner.config_ref(),  // resolved weights path
    None,                  // tokenizer auto-found
    "Explain attention in one paragraph.",
)?;
runner.generate(&prompt_ids, 64, |tok| print!("{tok} "))?;

Multimodal projector

use rlx_gemma::{GemmaMultimodalConfig, GemmaMultimodalRunner, MultimodalWeights};
use rlx_runtime::Device;

let cfg = GemmaMultimodalConfig::from_file("/path/to/gemma-4-12b/config.json")?;
let weights = MultimodalWeights::from_safetensors("/path/to/gemma-4-12b/model.safetensors")?;

let mut mm = GemmaMultimodalRunner::new(cfg, /* lm_hidden */ 3840, Device::Metal, None, None)?;

// JPEG/PNG → ImageNet-normalized patches → projector → soft tokens.
let image_soft = mm.project_image_file("photo.jpg", &weights, /* max_side_patches */ 32)?;

// WAV → 16 kHz mono → frames → projector → audio soft tokens.
let audio_soft = mm.project_audio_file("clip.wav", &weights)?;

// Tokenize "describe <image> while listening to <audio>" with the right
// number of placeholder tokens for each slot.
let encode = |s: &str| rlx_gemma::encode_prompt_auto("...".as_ref(), None, s);
let ids = mm.tokenize_prompt(
    "describe <image> while listening to <audio>",
    /* image_count */ 1,
    &[audio_samples_len],  // total samples in the audio clip
    encode,
)?;

// After embedding the LM token stream, splice the media rows in.
mm.fuse_text_and_media(&mut text_embeds, &ids, &image_soft, &audio_soft)?;
// `text_embeds` now feeds the first LM decoder block.

Inspect per-layer dispatch

use rlx_gemma::GemmaConfig;

let cfg = GemmaConfig::from_file("/path/to/gemma-4-12b/config.json")?;
assert_eq!(cfg.arch, rlx_gemma::GemmaArch::Gemma4);

// Sliding layer (layer index 0).
assert_eq!(cfg.layer_head_dim(0), 256);
assert_eq!(cfg.layer_num_kv_heads(0), 8);
assert_eq!(cfg.layer_n_rot(0), 256);
assert!((cfg.layer_rope_theta(0) - 10_000.0).abs() < 1e-3);

// Full-attention layer: every 6th layer, so 0-based index 5 is the first one.
assert!(cfg.is_full_attention_layer(5));
assert_eq!(cfg.layer_head_dim(5), 512);
assert_eq!(cfg.layer_num_kv_heads(5), 1);
assert_eq!(cfg.layer_n_rot(5), 128);   // 512 * 0.25
assert!((cfg.layer_rope_theta(5) - 1_000_000.0).abs() < 1e-3);

Architecture notes

Per-layer attention without new IR ops

Every Gemma 4 quirk is expressible with the existing IR op surface:

  • Per-layer head_dim / num_kv_heads ⇒ per-layer gemma_attn_spec(...) with different shapes. Each backend already lowers MatMul at arbitrary shapes.
  • Partial RoPE (p-RoPE) ⇒ Op::Rope { head_dim, n_rot } with n_rot < head_dim. This was added for Qwen3.5 MRoPE; every backend already implements it.
  • Split RoPE tables ⇒ a second RopeTablesStage::param_named("global", …) publishes (cos, sin) into FlowState::named under the global slot. Self-attn and decode emitters opt in via SelfAttnPrefillSpec::rope_table / GemmaDecodeLayerSpec::rope_table.
  • attention_k_eq_v ⇒ the emitter skips the V matmul and the V weight load. Backend kernels see only Q and K — V aliases K's HIR node id.
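A CPU sketch of the two attention tricks above: partial RoPE that rotates only the first n_rot dimensions of a head, and attention_k_eq_v, where V is taken from the K projection before RoPE. Assumptions: the rotate-half pairing (i, i + n_rot/2) and the frequency schedule theta^(-2i / n_rot) are the common Hugging Face partial-rotary conventions, not confirmed details of Op::Rope or of Gemma 4's proportional RoPE. The function names are hypothetical, and a copy stands in for the HIR node aliasing the real graph uses.

```rust
// Hypothetical helper names; not rlx-gemma's emitters.

/// Rotate the first `n_rot` dims of one head vector in place; dims
/// n_rot..head_dim pass through unchanged (partial RoPE).
/// Assumed conventions: rotate-half pairing (i, i + n_rot/2) and
/// inverse frequency theta^(-2i / n_rot).
pub fn apply_partial_rope(x: &mut [f32], pos: usize, n_rot: usize, theta: f64) {
    assert!(n_rot % 2 == 0 && n_rot <= x.len(), "n_rot must be even and <= head_dim");
    let half = n_rot / 2;
    for i in 0..half {
        let inv_freq = theta.powf(-(2.0 * i as f64) / n_rot as f64);
        let (sin, cos) = (pos as f64 * inv_freq).sin_cos();
        let (a, b) = (x[i] as f64, x[i + half] as f64);
        x[i] = (a * cos - b * sin) as f32;
        x[i + half] = (a * sin + b * cos) as f32;
    }
}

/// attention_k_eq_v: V is the K projection output *before* RoPE, so no
/// V matmul is needed. The real graph aliases the node; this copies it.
pub fn k_and_v_from_k_proj(
    k_proj_out: &[f32],
    pos: usize,
    n_rot: usize,
    theta: f64,
) -> (Vec<f32>, Vec<f32>) {
    let v = k_proj_out.to_vec(); // captured before rotation
    let mut k = k_proj_out.to_vec();
    apply_partial_rope(&mut k, pos, n_rot, theta);
    (k, v)
}
```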

Multimodal pipeline

                                    +----------------+
   prompt template  ───────────────►│ tokenize_with_ │
   "<image> see <audio>"            │     media      │──► [u32] token stream
                                    +----------------+
                                              │
                                              ▼
 photo.jpg ──► load_image_patches  ──► [B, P, F]
                                              │
                                              ▼  build_vision_projection_graph
                                       compiled vision projector
                                              │
                                              ▼
 clip.wav ──► load_wav + framing   ──► [B, frames, samples_per_token]
                                              │
                                              ▼  build_audio_projection_graph
                                       compiled audio projector
                                              │
                                              ▼
 LM embed(token_ids) ────────────► fuse_multimodal_embeddings
                                              │
                                              ▼
                                       LM decoder blocks

Each box is a separately-runnable graph or pure-CPU function, so callers can swap in their own image decoder, tokenizer, or fusion policy without touching the rest.
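The fusion step conceptually overwrites the embedding row at each placeholder token with the next projected media row. A minimal sketch of that splice, assuming row-major [seq_len, hidden] buffers and one call per placeholder id (for example image_token_id=258_880, then audio_token_id=258_881); splice_media_rows is a hypothetical name, not the crate's fuse_multimodal_embeddings.

```rust
// Hypothetical helper name; not rlx-gemma's fusion function.

/// Overwrite embedding rows at placeholder positions with media rows, in order.
/// `embeds`: row-major [seq_len, hidden]; `media`: row-major [n_rows, hidden].
pub fn splice_media_rows(
    embeds: &mut [f32],
    token_ids: &[u32],
    hidden: usize,
    placeholder_id: u32,
    media: &[f32],
) -> Result<(), String> {
    assert!(hidden > 0 && embeds.len() == token_ids.len() * hidden, "embeds shape mismatch");
    // Positions of every placeholder token, in prompt order.
    let slots: Vec<usize> = token_ids
        .iter()
        .enumerate()
        .filter(|&(_, &t)| t == placeholder_id)
        .map(|(i, _)| i)
        .collect();
    if slots.len() * hidden != media.len() {
        return Err(format!(
            "{} placeholders need {} media values, got {}",
            slots.len(),
            slots.len() * hidden,
            media.len()
        ));
    }
    for (row, &pos) in slots.iter().enumerate() {
        let dst = &mut embeds[pos * hidden..(pos + 1) * hidden];
        dst.copy_from_slice(&media[row * hidden..(row + 1) * hidden]);
    }
    Ok(())
}
```

Checking the placeholder count against the media row count before writing anything catches a prompt that was tokenized for a different number of soft tokens.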

Weight formats

Safetensors

GemmaRunner::builder().weights("model.safetensors") accepts F32, F16, and BF16 safetensors directly. The runner uses WeightMap from rlx-models-core, which decodes the source dtype to F32 at load time. Sharded checkpoints (model-*.safetensors) are loaded through from_safetensors_dir.

MultimodalWeights::from_safetensors carries the same decoder for the vision_tower.* / audio_tower.* projector weights.
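Widening BF16 to F32 at load time is lossless because BF16 is the upper 16 bits of an IEEE f32. A minimal sketch over raw little-endian tensor bytes (safetensors stores tensor data little-endian); bf16_bytes_to_f32 is a hypothetical helper, not WeightMap's API. F16 needs a real exponent/mantissa decode rather than a shift.

```rust
// Hypothetical helper name; not rlx-models-core's WeightMap API.

/// Widen little-endian BF16 bytes to f32. BF16 is the top half of an f32
/// bit pattern, so widening is a 16-bit left shift with no rounding.
pub fn bf16_bytes_to_f32(bytes: &[u8]) -> Vec<f32> {
    assert!(bytes.len() % 2 == 0, "BF16 data must have an even byte count");
    bytes
        .chunks_exact(2)
        .map(|b| f32::from_bits((u16::from_le_bytes([b[0], b[1]]) as u32) << 16))
        .collect()
}
```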

GGUF (quantized)

Gemma 4 GGUF support is wired through rlx-models-core's GgufModelFamily::Gemma — accepts gemma, gemma2, gemma3, gemma3n, gemma4, gemma4moe, and gemma4_unified arch tags. GemmaConfig::from_gguf reads metadata keys under the matching arch prefix.

The standard GGUF path (GemmaRunner::builder().weights("model.gguf")) dequantizes weights to F32 at load time, then runs F32 inference. This works for Q8_0, Q4_K_M, Q5_K_S, and the other GGUF quantization types whose dequant routines live in rlx-gguf. The memory cost is the full F32 footprint (~48 GB for Gemma 4 12B, i.e. 12B parameters × 4 bytes), so this path trades memory for implementation simplicity.
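Q8_0 is the simplest GGUF quantization type to dequantize: each block holds one f16 scale d followed by 32 signed 8-bit weights, and each weight is d * q. A minimal sketch of that layout; f16_to_f32 and dequant_q8_0 are hypothetical names, not rlx-gguf's API. K-quants such as Q4_K use 256-weight super-blocks with extra per-sub-block scales and are not covered here.

```rust
// Hypothetical helper names; not rlx-gguf's API.

const QK8_0: usize = 32; // weights per Q8_0 block
const BLOCK_BYTES: usize = 2 + QK8_0; // f16 scale + 32 x i8

/// Decode an IEEE 754 half-precision value (zero, subnormal, normal, inf/NaN).
pub fn f16_to_f32(h: u16) -> f32 {
    let sign: f32 = if h & 0x8000 != 0 { -1.0 } else { 1.0 };
    let exp = ((h >> 10) & 0x1F) as i32;
    let mant = (h & 0x3FF) as f32;
    match exp {
        0 => sign * mant * 2f32.powi(-24), // zero or subnormal
        31 => if mant == 0.0 { sign * f32::INFINITY } else { f32::NAN },
        _ => sign * (1.0 + mant / 1024.0) * 2f32.powi(exp - 15),
    }
}

/// Dequantize a buffer of Q8_0 blocks to f32: x = d * q.
pub fn dequant_q8_0(data: &[u8]) -> Vec<f32> {
    assert!(data.len() % BLOCK_BYTES == 0, "truncated Q8_0 data");
    let mut out = Vec::with_capacity(data.len() / BLOCK_BYTES * QK8_0);
    for block in data.chunks_exact(BLOCK_BYTES) {
        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
        // Reinterpret each byte as a signed 8-bit weight, then scale.
        out.extend(block[2..].iter().map(|&q| d * (q as i8) as f32));
    }
    out
}
```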

Packed quantized inference

GemmaRunner::builder().packed_weights(true) is the memory-saving path — weights stay in their GGUF-quantized layout and Op::DequantMatMul runs on-the-fly during matmul. The packed builder covers every Gemma 1/2/3/4 path the F32 builder does: per-layer head_dim / num_kv_heads / n_rot, split RoPE, attention_k_eq_v, attention soft-cap, final logit soft-cap, alternating sliding-window mask. Q/K/V/O and the gate/up/down MLP projections all route through DequantMatMul when the source tensor is K-quantized; tensors that the GGUF loader returns as F32 (RMS norm gammas, embed_tokens) stay on the regular MatMul path.
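Packed inference saves memory by dequantizing inside the matmul instead of materializing F32 weights. A minimal matrix-vector sketch of that technique over Q8_0-style blocks; it illustrates the idea rather than the Op::DequantMatMul kernel, the names are hypothetical, and it simplifies by holding each block scale as an f32 (GGUF stores it as f16).

```rust
// Hypothetical types and names; not the DequantMatMul kernel.

/// One Q8_0-style block with its scale already widened to f32.
#[derive(Clone, Copy)]
pub struct Q8Block {
    pub d: f32,
    pub qs: [i8; 32],
}

/// y = W x where each row of W is stored as packed Q8 blocks.
/// Weights are dequantized on the fly and never stored as f32.
pub fn dequant_matvec(rows: &[Vec<Q8Block>], x: &[f32]) -> Vec<f32> {
    rows.iter()
        .map(|row| {
            assert_eq!(row.len() * 32, x.len(), "row length mismatch");
            row.iter()
                .enumerate()
                .map(|(b, blk)| {
                    let xs = &x[b * 32..(b + 1) * 32];
                    // Dot product against the raw int8 weights, then one scale multiply.
                    let partial: f32 = blk.qs.iter().zip(xs).map(|(&q, &xv)| q as f32 * xv).sum();
                    blk.d * partial
                })
                .sum::<f32>()
        })
        .collect()
}
```

Applying the scale once per block rather than once per weight is the usual reason block-quantized formats keep a per-block scale.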

Multimodal weights (separate GGUF)

llama.cpp ships multimodal projector weights as a separate GGUF file (the mmproj.gguf companion next to the LM model.gguf — used for LLaVA, MiniCPM-V, Gemma 3 vision, and the upcoming Gemma 4 unified). Both loaders are wired:

// HF safetensors path (one tensor decoder, F32/F16/BF16):
let w = MultimodalWeights::from_safetensors("model.safetensors")?;

// llama.cpp mmproj.gguf path (decodes any K-quant rlx-gguf supports):
let w = MultimodalWeights::from_mmproj_gguf("mmproj.gguf")?;

from_mmproj_gguf drains every tensor whose name starts with vision_tower. / audio_tower. and dequantizes to F32 at load. The runner doesn't care which loader supplied the bytes — both produce the same MultimodalWeights map.
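Selecting projector tensors by name prefix works the same way regardless of which loader produced them. A minimal sketch over an in-memory name → buffer map, assuming the tensors have already been dequantized to f32; drain_projector_tensors is a hypothetical helper, and the tensor names in the test are illustrative only.

```rust
// Hypothetical helper name; not MultimodalWeights' API.
use std::collections::HashMap;

/// Move every vision_tower.* / audio_tower.* tensor out of `all`
/// into a new map, leaving the LM tensors behind.
pub fn drain_projector_tensors(all: &mut HashMap<String, Vec<f32>>) -> HashMap<String, Vec<f32>> {
    const PREFIXES: [&str; 2] = ["vision_tower.", "audio_tower."];
    let keys: Vec<String> = all
        .keys()
        .filter(|k| PREFIXES.iter().any(|p| k.starts_with(*p)))
        .cloned()
        .collect();
    keys.into_iter().filter_map(|k| all.remove_entry(&k)).collect()
}
```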

Quantization story summary

Format                              Memory                            Status
F32 safetensors                     48 GB (Gemma 4 12B)
F16 / BF16 safetensors              24 GB on disk → 48 GB in RAM      ✅ (dequant to F32 at load)
GGUF Q8_0 / Q4_K_M (default load)   6–13 GB on disk → 48 GB in RAM    ✅ (dequant at load)
GGUF + packed_weights(true)         6–13 GB in RAM                    ✅ (Op::DequantMatMul per projection)
mmproj.gguf (multimodal)            depends on K-quant                ✅ via MultimodalWeights::from_mmproj_gguf
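The memory column follows from parameters × bits per weight. A rough sketch, assuming a nominal 12e9 parameters and ignoring metadata, activations, and the KV cache: Q8_0 costs 8.5 bits per weight (34 bytes per 32 weights) and Q4_K costs 4.5 (144 bytes per 256 weights), which brackets the 6–13 GB range. Q4_K_M mixes in some higher-bit tensors, so it lands slightly above the Q4_K figure. weight_gb and the constant names are hypothetical.

```rust
// Hypothetical names; back-of-the-envelope arithmetic only.

/// GGUF Q8_0: one f16 scale (2 bytes) + 32 int8 weights per 32-weight block.
pub const Q8_0_BPW: f64 = (2.0 + 32.0) * 8.0 / 32.0; // 8.5 bits per weight
/// GGUF Q4_K: 144 bytes per 256-weight super-block.
pub const Q4_K_BPW: f64 = 144.0 * 8.0 / 256.0; // 4.5 bits per weight

/// Approximate weight storage in decimal gigabytes.
pub fn weight_gb(params: f64, bits_per_weight: f64) -> f64 {
    params * bits_per_weight / 8.0 / 1e9
}
```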

See also

  • Main repo README — Gemma entry under What's here
  • AGENTS.md: just recipes and CI commands

License

GPL-3.0. See workspace LICENSE.