rlx-gemma
Google Gemma, Gemma 2, Gemma 3, and Gemma 4 causal LMs on RLX. Single graph layer, every backend (CPU / Metal / MLX / CUDA / ROCm / wgpu / Vulkan), no Python.
Versions and features
| Family | Status | Notes |
|---|---|---|
| Gemma 1 (2b, 7b) | ✅ | Reference dense decoder. |
| Gemma 2 (9b, 27b) | ✅ | Pre/post FFN RMS norms, attention soft-cap, logit soft-cap, alternating sliding-window mask. |
| Gemma 3 (1b–27b) | ✅ | Stride-6 sliding-window pattern (5 sliding + 1 full causal). |
| Gemma 4 (E2B, E4B, 12B unified) | ✅ LM, ⚠️ multimodal | See below. |
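The layer schedules and soft-caps in the table reduce to a few lines of index arithmetic. This is a minimal sketch that assumes Gemma 2 starts with a sliding layer at index 0 and that soft-capping is the usual `cap * tanh(x / cap)`. `LayerKind`, `gemma2_layer_kind`, `stride6_layer_kind`, and `softcap` are hypothetical names, not this crate's API.

```rust
/// Attention flavour of one decoder layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerKind {
    Sliding,
    Full,
}

/// Gemma 2: sliding and full layers alternate.
/// Assumption: layer 0 (0-indexed) is a sliding layer.
pub fn gemma2_layer_kind(layer: usize) -> LayerKind {
    if layer % 2 == 0 {
        LayerKind::Sliding
    } else {
        LayerKind::Full
    }
}

/// Gemma 3: stride-6 pattern, 5 sliding layers then 1 full-causal layer,
/// i.e. every 6th layer counting from 1 is full attention.
pub fn stride6_layer_kind(layer: usize) -> LayerKind {
    if (layer + 1) % 6 == 0 {
        LayerKind::Full
    } else {
        LayerKind::Sliding
    }
}

/// Soft-cap: smoothly squashes `x` into [-cap, cap].
pub fn softcap(x: f32, cap: f32) -> f32 {
    cap * (x / cap).tanh()
}
```

With `cap = 30.0`, logits well inside the cap pass through almost unchanged, while extreme values saturate at ±30.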
Gemma 4 unified (google/gemma-4-12B)
The 12B "unified" model is encoder-free: image patches and audio frames are projected directly into the LM embedding space. This crate implements:
- Per-layer attention shape. Sliding layers use `head_dim=256`, `num_kv_heads=8`. Full-attention layers use `global_head_dim=512`, `num_global_key_value_heads=1`. KV cache shapes diverge per layer in the decode flow.
- Split RoPE. Sliding layers run plain RoPE at `theta=10_000`. Full-attention layers run proportional RoPE with `partial_rotary_factor=0.25` at `theta=1_000_000`. Both prefill (static) and decode (static + dynamic) paths emit two RoPE tables and dispatch per layer.
- `attention_k_eq_v`. The V projection is aliased to K (pre-RoPE), so the V matmul and the `v_proj.weight` load are skipped at runtime.
- `final_logit_softcapping=30.0`, wired through the existing `LogitSoftcap` stage.
- Unified token vocab (`vocab_size=262_144`) with `image_token_id=258_880`, `audio_token_id=258_881`, `video_token_id=258_884`, and the surrounding `boi`/`eoi`/`boa`/`eoa` markers.
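The per-layer shapes listed above fit in a small spec struct. This sketch hard-codes the 12B values from this section and assumes the stride-6 pattern puts a full-attention layer at every 6th (1-indexed) layer. `AttnSpec` and `gemma4_attn_spec` are hypothetical; the real values come from the checkpoint config.

```rust
/// Attention shape for one Gemma 4 12B decoder layer (values from this README).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AttnSpec {
    pub head_dim: usize,
    pub num_kv_heads: usize,
    /// Number of leading dims per head that RoPE rotates.
    pub n_rot: usize,
    pub rope_theta: f64,
}

impl AttnSpec {
    /// Elements per token in one cached K row for this layer.
    pub fn kv_row_len(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }
}

/// Hypothetical per-layer dispatch, assuming every 6th layer (1-indexed) is full attention.
pub fn gemma4_attn_spec(layer: usize) -> AttnSpec {
    let is_full = (layer + 1) % 6 == 0;
    if is_full {
        let head_dim = 512; // global_head_dim
        let partial_rotary_factor = 0.25;
        AttnSpec {
            head_dim,
            num_kv_heads: 1, // num_global_key_value_heads
            n_rot: (head_dim as f64 * partial_rotary_factor).round() as usize,
            rope_theta: 1_000_000.0,
        }
    } else {
        AttnSpec {
            head_dim: 256,
            num_kv_heads: 8,
            n_rot: 256, // plain RoPE rotates the whole head
            rope_theta: 10_000.0,
        }
    }
}
```

The cached K row is 2048 elements per token on sliding layers but 512 on full-attention layers, which is why the decode flow cannot use one KV cache shape for every layer.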
Multimodal status:
- ✅ Standalone vision + audio projector graphs (matmul + RMS norm + LM projection). Numerically verified on CPU, Metal, MLX, and wgpu.
- ✅ Preprocessing: JPEG/PNG → patches (ImageNet mean/std by default), WAV → 16 kHz mono f32 (linear resampler), prompt template → media-placeholder token stream.
- ✅ Safetensors loader for `vision_tower.*` / `audio_tower.*` weights (F32 / F16 / BF16).
- ⚠️ Vision pool-to-soft-tokens uses a placeholder matmul; the reference cross-attention "Q-Former-style" reduction from P patches to `num_soft_tokens` is not yet wired (output shape is `[B, P, output_proj_dims]` today). It will be fixed in place once the reference projector weights are pinned; see [`build_vision_projection_hir`].
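The preprocessing steps above can be sketched in plain Rust. The mean/std constants are the standard ImageNet values; the interleaved HWC pixel layout, the channel-averaging downmix, and the exact interpolation rule are assumptions and may differ from the crate's loaders. All function names are hypothetical.

```rust
/// Standard ImageNet per-channel statistics (RGB order).
pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
pub const IMAGENET_STD: [f32; 3] = [0.229, 0.224, 0.225];

/// Normalize interleaved RGB bytes (HWC layout assumed) to f32.
pub fn normalize_rgb(pixels: &[u8]) -> Vec<f32> {
    pixels
        .iter()
        .enumerate()
        .map(|(i, &p)| {
            let c = i % 3;
            (p as f32 / 255.0 - IMAGENET_MEAN[c]) / IMAGENET_STD[c]
        })
        .collect()
}

/// Average interleaved channels into one mono channel.
pub fn downmix(samples: &[f32], channels: usize) -> Vec<f32> {
    assert!(channels > 0, "channel count must be positive");
    samples
        .chunks_exact(channels)
        .map(|frame| frame.iter().sum::<f32>() / channels as f32)
        .collect()
}

/// Linear-interpolation resampler from `in_rate` to `out_rate` Hz.
pub fn resample_linear(input: &[f32], in_rate: u32, out_rate: u32) -> Vec<f32> {
    if input.is_empty() || in_rate == out_rate {
        return input.to_vec();
    }
    let out_len = (input.len() as u64 * out_rate as u64 / in_rate as u64) as usize;
    let step = in_rate as f64 / out_rate as f64; // input samples per output sample
    let last = input.len() - 1;
    (0..out_len)
        .map(|i| {
            let pos = i as f64 * step;
            let j = (pos.floor() as usize).min(last);
            let frac = (pos - j as f64) as f32;
            // Interpolate between the two neighbouring input samples.
            let (a, b) = (input[j], input[(j + 1).min(last)]);
            a + (b - a) * frac
        })
        .collect()
}
```

For a 44.1 kHz stereo WAV, the sketch would call `downmix(&samples, 2)` and then `resample_linear(&mono, 44_100, 16_000)`.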
Backend coverage
Build with `cargo build -p rlx-gemma --features <feature>`. Feature → backend:
| Feature | Device it enables |
|---|---|
| (default) | CPU |
| `metal` | Apple Metal |
| `mlx` | Apple MLX |
| `cuda` | NVIDIA CUDA |
| `rocm` | AMD ROCm |
| `gpu` | wgpu (Metal/Vulkan/DX12) |
| `vulkan` | Vulkan compute |
| `apple-silicon` | `metal` + `mlx` + `gpu` |
| `all-backends` | every device |
Verified parity (M-series macOS, apple-silicon)
```sh
cargo test -p rlx-gemma --features apple-silicon
# 106 / 106 pass:
#   54 lib (config, flow, multimodal, runner)
#    8 gemma4_backend_parity (projector vs CPU on Metal/MLX/wgpu)
#   28 gemma4_lm_backend_parity (prefill + Gemma 1/2/3 regression)
#    4 gemma4_decode_backend_parity (full per-layer KV divergence)
#    4 gemma4_decode_throughput (tok/s regression detection)
#    8 gemma4_reference_fixture (3 auto-skip without RLX_GEMMA4_FIXTURE; 5 offline flow tests)
```
Numerical parity vs CPU on the Gemma-4 LM graph (k_eq_v + split RoPE + per-layer head_dim variation, tiny 6-layer config):
| Backend | LM max \|Δ\| | Projector max \|Δ\| |
|---|---|---|
| MLX | 8.6e-8 (essentially bit-exact) | < 1e-3 |
| wgpu | 1.2e-6 | < 1e-3 |
| Metal causal | 1.1e-6 (precise scalar sgemm) | < 1e-3 |
| Metal sliding | 1.1e-6 (SDPA `mask_kind=4` wired) | < 1e-3 |
| CUDA / ROCm / Vulkan | compile-time gated, not run on this host | compile-time gated |
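Each max |Δ| figure in the table is the largest element-wise absolute difference between a backend's output and the CPU reference. Below is a sketch of that metric with hypothetical helper names, not the crate's test utilities. It deliberately treats a NaN difference as infinite, because `f32::max` on its own skips NaN and would let a broken kernel pass.

```rust
/// Largest element-wise |a - b|. A NaN difference counts as +inf, because
/// `f32::max` alone would silently skip it and let a broken output pass.
pub fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "outputs must have the same length");
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let d = (x - y).abs();
            if d.is_nan() { f32::INFINITY } else { d }
        })
        .fold(0.0_f32, f32::max)
}

/// Parity check: passes when max |Δ| is within `tol`.
pub fn within_tolerance(reference: &[f32], candidate: &[f32], tol: f32) -> bool {
    max_abs_diff(reference, candidate) <= tol
}
```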
Metal precision: a knob, not a floor
Apple Silicon's `simdgroup_float8x8` matmul units use reduced-precision
internal accumulators (roughly fp16 class for the tensor-multiply step).
That is fine for production inference, where the resulting ~1e-3
relative error is kept in check by the numerically stable softmax and the
layer norms, but it is visible (~1e-1 absolute) on the tiny parity-test
logits. Set `RLX_METAL_SGEMM_VARIANT=naive` (or `RLX_METAL_PRECISE=1`)
to route matmuls through the scalar fp32 sgemm kernel and recover full
CPU parity. The parity tests do this automatically.
This work also tightened `rlx-metal`'s SDPA kernels to use
`precise::sqrt` + `precise::exp` (avoiding the relaxed-precision
`metal::` namespace defaults; see `kernels.rs::sdpa{,_long,_h}`) and
wired `MaskKind::SlidingWindow(w)` end-to-end, so Gemma 3 / 4 alternating
attention now runs on Metal where it previously panicked.
- CUDA / ROCm / Vulkan have compile-time parity (the `is_available()`-gated tests auto-skip without a live driver) and will execute when a host with the matching backend runs `cargo test --features cuda` (or `rocm`, `vulkan`).
Examples
Load a config and run the text LM
```rust
use ;
use SampleOpts;
use Device;

let mut runner = builder
    .weights
    .device
    .max_seq
    .sample
    .build?;
let prompt_ids = encode_prompt_auto?;
runner.generate?;
```
Multimodal projector
```rust
use ;
use Device;

let cfg = from_file?;
let weights = from_safetensors?;
let mut mm = new?;

// JPEG/PNG → ImageNet-normalized patches → projector → soft tokens.
let image_soft = mm.project_image_file?;

// WAV → 16 kHz mono → frames → projector → audio soft tokens.
let audio_soft = mm.project_audio_file?;

// Tokenize "describe <image> while listening to <audio>" with the right
// number of placeholder tokens for each slot.
let encode = ;
let ids = mm.tokenize_prompt?;

// After embedding the LM token stream, splice the media rows in.
mm.fuse_text_and_media?;
// `text_embeds` now feeds the first LM decoder block.
```
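The last two steps of the multimodal example (placeholder expansion and fusion) amount to index bookkeeping. This minimal sketch assumes row-major `[seq, hidden]` embeddings and exactly one media row per placeholder token. `MediaSlot`, `expand_slot`, and `fuse_rows` are hypothetical stand-ins for `tokenize_prompt` and `fuse_text_and_media`, whose real signatures are not shown here. Begin/end marker ids are parameters because this README does not list the `boi`/`eoi` values.

```rust
/// A media slot in the token stream: a begin marker, `count` placeholder
/// tokens (one per soft-token row), and an end marker.
pub struct MediaSlot {
    pub begin: u32,
    pub placeholder: u32,
    pub end: u32,
    pub count: usize,
}

/// Append the token ids for one media slot to `out`.
pub fn expand_slot(out: &mut Vec<u32>, slot: &MediaSlot) {
    out.push(slot.begin);
    out.extend(std::iter::repeat(slot.placeholder).take(slot.count));
    out.push(slot.end);
}

/// Overwrite each row of `text_embeds` ([seq, hidden], row-major) whose token
/// id equals `placeholder` with the next row of `media` ([n, hidden]).
pub fn fuse_rows(
    text_embeds: &mut [f32],
    ids: &[u32],
    placeholder: u32,
    media: &[f32],
    hidden: usize,
) -> Result<(), String> {
    if text_embeds.len() != ids.len() * hidden {
        return Err("text_embeds must be [ids.len(), hidden]".to_string());
    }
    let positions: Vec<usize> = ids
        .iter()
        .enumerate()
        .filter_map(|(i, &id)| (id == placeholder).then_some(i))
        .collect();
    if positions.len() * hidden != media.len() {
        return Err(format!(
            "{} placeholder tokens but {} media rows",
            positions.len(),
            media.len() / hidden.max(1)
        ));
    }
    for (row, &pos) in positions.iter().enumerate() {
        let dst = &mut text_embeds[pos * hidden..(pos + 1) * hidden];
        dst.copy_from_slice(&media[row * hidden..(row + 1) * hidden]);
    }
    Ok(())
}
```

Checking the placeholder count against the media rows up front turns a tokenizer/projector mismatch into an error instead of a silently misaligned splice.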
Inspect per-layer dispatch
```rust
let cfg = from_file?;
assert_eq!;

// Sliding layer (layer index 0).
assert_eq!;
assert_eq!;
assert_eq!;
assert!;

// Full-attention layer (every 6th, 1-indexed).
assert!;
assert_eq!;
assert_eq!;
assert_eq!; // 512 * 0.25
assert!;
```
Architecture notes
Per-layer attention without new IR ops
Every Gemma 4 quirk is expressible with the existing IR op surface:
- Per-layer `head_dim` / `num_kv_heads` ⇒ per-layer `gemma_attn_spec(...)` with different shapes. Each backend already lowers `MatMul` at arbitrary shapes.
- Partial RoPE (p-RoPE) ⇒ `Op::Rope { head_dim, n_rot }` with `n_rot < head_dim`. This was added for Qwen3.5 MRoPE; every backend already implements it.
- Split RoPE tables ⇒ a second `RopeTablesStage::param_named("global", …)` publishes `(cos, sin)` into `FlowState::named` under the `global` slot. Self-attn and decode emitters opt in via `SelfAttnPrefillSpec::rope_table` / `GemmaDecodeLayerSpec::rope_table`.
- `attention_k_eq_v` ⇒ the emitter skips the V matmul and the V weight load. Backend kernels see only Q and K; V aliases K's HIR node id.
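`Op::Rope { head_dim, n_rot }` with `n_rot < head_dim` rotates only part of each head. This sketch uses the rotate-half pairing `(i, i + n_rot/2)` and inverse frequencies `theta^(-2i / n_rot)`, which is the common partial-rotary convention. Whether Gemma 4's proportional RoPE computes its frequencies exactly this way is an assumption, and `apply_partial_rope` is a hypothetical name.

```rust
/// Rotate the first `n_rot` dims of one head vector `x` in place for token
/// position `pos`; dims `n_rot..x.len()` pass through unchanged.
/// Pairing is rotate-half style: dim `i` pairs with dim `i + n_rot / 2`.
pub fn apply_partial_rope(x: &mut [f32], n_rot: usize, pos: usize, theta: f32) {
    assert!(n_rot % 2 == 0 && n_rot <= x.len(), "n_rot must be even and <= head_dim");
    let half = n_rot / 2;
    for i in 0..half {
        // Inverse frequency for this pair: theta^(-2i / n_rot) (assumed schedule).
        let inv_freq = theta.powf(-2.0 * i as f32 / n_rot as f32);
        let (sin, cos) = (pos as f32 * inv_freq).sin_cos();
        let (a, b) = (x[i], x[i + half]);
        // 2-D rotation of the (a, b) pair by the position-dependent angle.
        x[i] = a * cos - b * sin;
        x[i + half] = b * cos + a * sin;
    }
}
```

With `head_dim = 512` and `partial_rotary_factor = 0.25`, a full-attention layer would call this with `n_rot = 128`, leaving the remaining 384 dims of each head untouched.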
Multimodal pipeline
```text
                                  +----------------+
prompt template ───────────────►│ tokenize_with_ │
          "<image> see <audio>"           │     media      │──► [u32] token stream
                                  +----------------+
                                          │
                                          ▼
photo.jpg ──► load_image_patches ──► [B, P, F]
                                          │
                                          ▼  build_vision_projection_graph
                                  compiled vision projector
                                          │
                                          ▼
clip.wav ──► load_wav + framing ──► [B, frames, samples_per_token]
                                          │
                                          ▼  build_audio_projection_graph
                                  compiled audio projector
                                          │
                                          ▼
LM embed(token_ids) ────────────► fuse_multimodal_embeddings
                                          │
                                          ▼
                                   LM decoder blocks
```
Each box is a separately runnable graph or pure-CPU function, so callers can swap in their own image decoder, tokenizer, or fusion policy without touching the rest.
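For the two preprocessing stages in the diagram, here is a sketch of patch extraction (`[H, W, C]` to `[P, F]` with `F = patch * patch * C`) and audio framing into `samples_per_token` chunks. The row-major HWC layout, the `(y, x, c)` ordering inside each patch, and zero-padding of the last frame are assumptions. `patchify` and `frame_audio` are hypothetical names, not `load_image_patches` itself.

```rust
/// Split a row-major HWC image into non-overlapping `patch x patch` tiles.
/// Returns (P, data) where data is [P, F] with F = patch * patch * c.
pub fn patchify(img: &[f32], h: usize, w: usize, c: usize, patch: usize) -> (usize, Vec<f32>) {
    assert_eq!(img.len(), h * w * c, "image buffer does not match h * w * c");
    assert!(patch > 0 && h % patch == 0 && w % patch == 0, "size must divide evenly");
    let (ph, pw) = (h / patch, w / patch);
    let mut out = Vec::with_capacity(img.len());
    for py in 0..ph {
        for px in 0..pw {
            // Copy one patch row at a time; each row is `patch * c` contiguous floats.
            for y in 0..patch {
                let row = py * patch + y;
                let start = (row * w + px * patch) * c;
                out.extend_from_slice(&img[start..start + patch * c]);
            }
        }
    }
    (ph * pw, out)
}

/// Chunk mono samples into frames of `samples_per_token`, zero-padding the last.
pub fn frame_audio(samples: &[f32], samples_per_token: usize) -> Vec<Vec<f32>> {
    assert!(samples_per_token > 0);
    samples
        .chunks(samples_per_token)
        .map(|chunk| {
            let mut frame = chunk.to_vec();
            frame.resize(samples_per_token, 0.0);
            frame
        })
        .collect()
}
```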
Weight formats
Safetensors
`GemmaRunner::builder().weights("model.safetensors")` accepts F32, F16,
and BF16 safetensors directly. The runner uses `rlx-models-core`'s
`WeightMap`, which decodes the source dtype to F32 at load time.
Sharded checkpoints (`model-*.safetensors`) are loaded through
`from_safetensors_dir`.
`MultimodalWeights::from_safetensors` carries the same decoder for the
`vision_tower.*` / `audio_tower.*` projector weights.
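Decoding F16 and BF16 to F32 is cheap bit manipulation, which keeps load-time widening simple. The price is memory: 12e9 weights × 4 bytes ≈ 48 GB. Below is a sketch of the per-element decode for little-endian safetensors buffers; the dtype strings `F32`/`F16`/`BF16` are the ones safetensors headers use, while the function names are hypothetical and `WeightMap`'s real implementation may differ.

```rust
/// BF16 is the top 16 bits of an IEEE-754 f32, so widening is a shift.
pub fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// IEEE-754 binary16 to f32, handling zero, subnormals, inf and NaN.
pub fn f16_to_f32(bits: u16) -> f32 {
    let sign = ((bits >> 15) & 1) as u32;
    let exp = ((bits >> 10) & 0x1f) as u32;
    let mant = (bits & 0x3ff) as u32;
    if exp == 0 {
        // Zero or subnormal: value = mant * 2^-24 (exact in f32).
        let v = mant as f32 * 2f32.powi(-24);
        return if sign == 1 { -v } else { v };
    }
    let out = if exp == 0x1f {
        // Inf (mant == 0) or NaN (mant != 0): max exponent, mantissa carried over.
        (sign << 31) | 0x7f80_0000 | (mant << 13)
    } else {
        // Normal: rebias the exponent from 15 to 127 and widen the mantissa.
        (sign << 31) | ((exp + 112) << 23) | (mant << 13)
    };
    f32::from_bits(out)
}

/// Decode a little-endian safetensors tensor buffer to f32.
pub fn decode_le(bytes: &[u8], dtype: &str) -> Option<Vec<f32>> {
    match dtype {
        "F32" => Some(bytes.chunks_exact(4).map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])).collect()),
        "F16" => Some(bytes.chunks_exact(2).map(|b| f16_to_f32(u16::from_le_bytes([b[0], b[1]]))).collect()),
        "BF16" => Some(bytes.chunks_exact(2).map(|b| bf16_to_f32(u16::from_le_bytes([b[0], b[1]]))).collect()),
        _ => None,
    }
}
```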
GGUF (quantized)
Gemma 4 GGUF support is wired through `rlx-models-core`'s
`GgufModelFamily::Gemma`, which accepts the `gemma`, `gemma2`, `gemma3`,
`gemma3n`, `gemma4`, `gemma4moe`, and `gemma4_unified` arch tags.
`GemmaConfig::from_gguf` reads metadata keys under the matching arch
prefix.
The standard GGUF path (`GemmaRunner::builder().weights("model.gguf")`)
dequantizes weights to F32 at load time, then runs F32 inference.
This works for Q8_0, Q4_K_M, Q5_K_S, and the other GGUF quantization
types covered by the dequant table in `rlx-gguf`. Memory cost is full
F32 (~48 GB for Gemma 4 12B): this path trades memory for
implementation simplicity.
Packed quantized inference
`GemmaRunner::builder().packed_weights(true)` is the memory-saving
path: weights stay in their GGUF-quantized layout and
`Op::DequantMatMul` dequantizes them on the fly inside the matmul. The
packed builder covers every Gemma 1/2/3/4 path the F32 builder does:
per-layer `head_dim` / `num_kv_heads` / `n_rot`, split RoPE,
`attention_k_eq_v`, attention soft-cap, final logit soft-cap, and the
alternating sliding-window mask. Q/K/V/O and the gate/up/down MLP
projections all route through `DequantMatMul` when the source tensor is
K-quantized; tensors that the GGUF loader returns as F32 (RMS norm
gammas, `embed_tokens`) stay on the regular `MatMul` path.
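The idea behind on-the-fly `DequantMatMul` is that quantized values are consumed inside the dot product instead of first being expanded into a full F32 matrix. Below is a simplified matrix-vector sketch over Q8_0-style blocks that assumes the per-block scale has already been widened to `f32` (real GGUF blocks store it as f16). `BlockQ8` and `dequant_matvec` are hypothetical, and this shows the arithmetic only, not a backend kernel.

```rust
pub const QK: usize = 32;

/// Q8_0-style block: 32 signed 8-bit weights sharing one scale.
/// Simplification: the scale is stored as f32 here (GGUF stores f16).
#[derive(Debug, Clone, Copy)]
pub struct BlockQ8 {
    pub d: f32,
    pub qs: [i8; QK],
}

/// y = W·x for W of shape [rows, cols], stored row-major as blocks.
/// Weights are dequantized inside the dot product; no F32 copy of W is made.
pub fn dequant_matvec(w: &[BlockQ8], rows: usize, cols: usize, x: &[f32]) -> Vec<f32> {
    assert_eq!(cols % QK, 0, "cols must be a multiple of the block size");
    assert_eq!(x.len(), cols);
    let blocks_per_row = cols / QK;
    assert_eq!(w.len(), rows * blocks_per_row);
    (0..rows)
        .map(|r| {
            let row = &w[r * blocks_per_row..(r + 1) * blocks_per_row];
            row.iter()
                .enumerate()
                .map(|(b, blk)| {
                    let xs = &x[b * QK..(b + 1) * QK];
                    // Accumulate q * x for the block, then apply the shared scale once.
                    let dot: f32 = blk.qs.iter().zip(xs).map(|(&q, &xv)| q as f32 * xv).sum();
                    blk.d * dot
                })
                .sum::<f32>()
        })
        .collect()
}
```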
Multimodal weights (separate GGUF)
llama.cpp ships multimodal projector weights as a separate GGUF
file (the `mmproj.gguf` companion next to the LM's `model.gguf`, used
for LLaVA, MiniCPM-V, Gemma 3 vision, and the upcoming Gemma 4
unified). Both loaders are wired:
```rust
// HF safetensors path (one tensor decoder, F32/F16/BF16):
let w = from_safetensors?;

// llama.cpp mmproj.gguf path (decodes any K-quant rlx-gguf supports):
let w = from_mmproj_gguf?;
```
`from_mmproj_gguf` drains every tensor whose name starts with
`vision_tower.` / `audio_tower.` and dequantizes it to F32 at load. The
runner doesn't care which loader supplied the bytes: both produce
the same `MultimodalWeights` map.
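The prefix filter described above can be sketched with a `HashMap`. `drain_prefixed` is a hypothetical helper, generic over the tensor type, and not `from_mmproj_gguf`'s actual code.

```rust
use std::collections::HashMap;

/// Remove and return every entry whose key starts with one of `prefixes`,
/// leaving the rest (e.g. LM tensors) in `all`.
pub fn drain_prefixed<T>(all: &mut HashMap<String, T>, prefixes: &[&str]) -> HashMap<String, T> {
    // Collect matching keys first so the map is not mutated while iterating.
    let keys: Vec<String> = all
        .keys()
        .filter(|k| prefixes.iter().any(|p| k.starts_with(*p)))
        .cloned()
        .collect();
    keys.into_iter()
        .filter_map(|k| all.remove(&k).map(|v| (k, v)))
        .collect()
}
```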
Quantization story summary
| Format | Memory | Status |
|---|---|---|
| F32 safetensors | 48 GB (Gemma 4 12B) | ✅ |
| F16 / BF16 safetensors | 24 GB on disk → 48 GB in RAM | ✅ (dequant to F32 at load) |
| GGUF Q8_0 / Q4_K_M (default load) | 6–13 GB on disk → 48 GB in RAM | ✅ (dequant at load) |
| GGUF + `packed_weights(true)` | 6–13 GB in RAM | ✅ (`Op::DequantMatMul` per projection) |
| `mmproj.gguf` (multimodal) | depends on K-quant | ✅ via `MultimodalWeights::from_mmproj_gguf` |
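The memory figures in the table follow from bits per weight. Q8_0 spends 34 bytes per 32 weights (8.5 bits), and a Q4_K super-block spends 144 bytes per 256 weights (4.5 bits). This sketch treats Gemma 4 12B as exactly 12e9 weights and ignores metadata; real Q4_K_M files mix in some higher-precision tensors, so they land somewhat above the uniform Q4_K estimate. `bits_per_weight` and `weight_gb` are hypothetical helpers.

```rust
/// Storage cost per weight in bits for a few formats.
pub fn bits_per_weight(format: &str) -> Option<f64> {
    match format {
        "F32" => Some(32.0),
        "F16" | "BF16" => Some(16.0),
        // 2-byte f16 scale + 32 int8 values per 32-weight block.
        "Q8_0" => Some(34.0 * 8.0 / 32.0),
        // 144-byte super-block (two f16 scales, 12 bytes of sub-scales,
        // 128 bytes of 4-bit quants) per 256 weights.
        "Q4_K" => Some(144.0 * 8.0 / 256.0),
        _ => None,
    }
}

/// Approximate weight memory in GB (1e9 bytes) for `params` weights.
pub fn weight_gb(params: f64, format: &str) -> Option<f64> {
    bits_per_weight(format).map(|bits| params * bits / 8.0 / 1e9)
}
```

`weight_gb(12e9, "Q4_K")` gives 6.75 GB and `weight_gb(12e9, "Q8_0")` gives 12.75 GB, consistent with the 6–13 GB range above, while `"F32"` gives the 48 GB figure.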
See also
License
GPL-3.0. See workspace LICENSE.