use std::collections::HashMap;
use std::path::Path;
use mlx_rs::ops;
use mlx_rs::ops::indexing::IndexOp;
use mlx_rs::Array;
use tokenizers::Tokenizer;
use tracing::info;
use super::mlx::{build_qembedding, build_qlinear, QEmbedding, QLinear, QuantConfig};
use crate::tasks::chat_template::ChatTemplate;
use crate::tasks::generate::ToolCall;
use crate::InferenceError;
/// Attention flavour of a decoder layer.
///
/// Parsed from `text_config.layer_types` in `config.json`: `"full_attention"` maps to
/// [`LayerKind::Full`], every other entry to [`LayerKind::Sliding`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerKind {
    /// Causal attention restricted to the most recent `sliding_window` positions.
    Sliding,
    /// Causal attention over the whole context.
    Full,
}
/// Text-model hyper-parameters for a Gemma 4 checkpoint, parsed from `config.json`.
pub struct Gemma4Config {
    /// Residual-stream width; embeddings are scaled by `sqrt(hidden_size)`.
    pub hidden_size: usize,
    /// Number of decoder blocks.
    pub num_hidden_layers: usize,
    /// Query heads per attention layer (both layer kinds).
    pub num_attention_heads: usize,
    /// Per-head width on sliding-window layers.
    pub head_dim: usize,
    /// Key/value heads on sliding-window layers.
    pub num_key_value_heads: usize,
    /// Per-head width on full-attention layers.
    pub global_head_dim: usize,
    /// Key/value heads on full-attention layers.
    pub num_global_key_value_heads: usize,
    /// Epsilon used by every RMS normalisation.
    pub rms_norm_eps: f32,
    /// Vocabulary size. NOTE(review): parsed but not read elsewhere in this file.
    pub vocab_size: usize,
    /// When set, final logits are soft-capped as `cap * tanh(logits / cap)`.
    pub final_logit_softcapping: Option<f32>,
    /// Attention kind of each decoder layer, indexed by layer number.
    pub layer_types: Vec<LayerKind>,
    /// Number of most recent positions (including the current one) visible to sliding layers.
    pub sliding_window: usize,
    /// RoPE base for sliding-window layers.
    pub rope_theta_sliding: f32,
    /// RoPE base for full-attention layers.
    pub rope_theta_global: f32,
    /// Fraction of `global_head_dim` that receives rotary embedding on full-attention layers.
    pub partial_rotary_factor: f32,
    /// Quantization settings handed to every quantized linear/embedding builder.
    pub quant: QuantConfig,
    /// Token ids that terminate generation.
    pub eos_token_ids: Vec<u32>,
}
impl Gemma4Config {
    /// Parses `config.json` in `model_dir` into a [`Gemma4Config`].
    ///
    /// Text-model hyper-parameters come from the nested `text_config` object. Quantization
    /// settings (`quantization`, falling back to `quantization_config`) and `eos_token_id`
    /// come from the root object. Missing keys fall back to the defaults written inline.
    ///
    /// # Errors
    /// Returns [`InferenceError::InferenceFailed`] if the file cannot be read or parsed, if
    /// `text_config` is missing, or if the configuration is internally inconsistent (see
    /// [`Gemma4Config::validate`]).
    fn parse(model_dir: &Path) -> Result<Self, InferenceError> {
        let path = model_dir.join("config.json");
        let raw = std::fs::read_to_string(&path)
            .map_err(|e| InferenceError::InferenceFailed(format!("read {}: {e}", path.display())))?;
        let root: serde_json::Value = serde_json::from_str(&raw)
            .map_err(|e| InferenceError::InferenceFailed(format!("parse config.json: {e}")))?;
        let tc = root.get("text_config").ok_or_else(|| {
            InferenceError::InferenceFailed("gemma4 config.json missing text_config".into())
        })?;
        // Integer / float lookups in `text_config` with a default for absent or mistyped keys.
        let u = |k: &str, d: u64| -> usize {
            tc.get(k).and_then(|x| x.as_u64()).unwrap_or(d) as usize
        };
        let f = |k: &str, d: f64| -> f32 { tc.get(k).and_then(|x| x.as_f64()).unwrap_or(d) as f32 };

        let num_hidden_layers = u("num_hidden_layers", 48);
        let layer_types = Self::parse_layer_types(tc, num_hidden_layers);
        let (rope_theta_sliding, rope_theta_global, partial_rotary_factor) = Self::parse_rope(tc);

        let q = root.get("quantization").or_else(|| root.get("quantization_config"));
        let quant = QuantConfig {
            group_size: q.and_then(|v| v.get("group_size")).and_then(|x| x.as_i64()).unwrap_or(64)
                as i32,
            bits: q.and_then(|v| v.get("bits")).and_then(|x| x.as_i64()).unwrap_or(4) as i32,
            mode: q
                .and_then(|v| v.get("mode"))
                .and_then(|x| x.as_str())
                .map(|s| s.to_string()),
        };
        let final_logit_softcapping = tc
            .get("final_logit_softcapping")
            .and_then(|x| x.as_f64())
            .map(|v| v as f32);

        let config = Self {
            hidden_size: u("hidden_size", 3840),
            num_hidden_layers,
            num_attention_heads: u("num_attention_heads", 16),
            head_dim: u("head_dim", 256),
            num_key_value_heads: u("num_key_value_heads", 8),
            global_head_dim: u("global_head_dim", 512),
            num_global_key_value_heads: u("num_global_key_value_heads", 1),
            rms_norm_eps: f("rms_norm_eps", 1e-6),
            vocab_size: u("vocab_size", 262_144),
            final_logit_softcapping,
            layer_types,
            sliding_window: u("sliding_window", 1024),
            rope_theta_sliding,
            rope_theta_global,
            partial_rotary_factor,
            quant,
            eos_token_ids: Self::parse_eos(&root),
        };
        config.validate()?;
        Ok(config)
    }

    /// Reads `layer_types`; when absent, every sixth layer (1-based) is full attention.
    fn parse_layer_types(tc: &serde_json::Value, num_hidden_layers: usize) -> Vec<LayerKind> {
        match tc.get("layer_types").and_then(|v| v.as_array()) {
            Some(arr) => arr
                .iter()
                .map(|v| match v.as_str() {
                    Some("full_attention") => LayerKind::Full,
                    _ => LayerKind::Sliding,
                })
                .collect(),
            None => (0..num_hidden_layers)
                .map(|i| {
                    if (i + 1) % 6 == 0 {
                        LayerKind::Full
                    } else {
                        LayerKind::Sliding
                    }
                })
                .collect(),
        }
    }

    /// Returns `(sliding rope_theta, full rope_theta, full partial_rotary_factor)` from
    /// `rope_parameters`, defaulting to `(10_000, 1_000_000, 0.25)` per missing value.
    fn parse_rope(tc: &serde_json::Value) -> (f32, f32, f32) {
        let rp = tc.get("rope_parameters");
        let sliding = rp.and_then(|r| r.get("sliding_attention"));
        let full = rp.and_then(|r| r.get("full_attention"));
        let num = |section: Option<&serde_json::Value>, key: &str, default: f64| -> f32 {
            section
                .and_then(|s| s.get(key))
                .and_then(|t| t.as_f64())
                .unwrap_or(default) as f32
        };
        (
            num(sliding, "rope_theta", 10_000.0),
            num(full, "rope_theta", 1_000_000.0),
            num(full, "partial_rotary_factor", 0.25),
        )
    }

    /// Reads root-level `eos_token_id` (single number or array); defaults to `[1, 106, 50]`.
    fn parse_eos(root: &serde_json::Value) -> Vec<u32> {
        match root.get("eos_token_id") {
            Some(serde_json::Value::Array(a)) => {
                a.iter().filter_map(|v| v.as_u64()).map(|v| v as u32).collect()
            }
            Some(serde_json::Value::Number(n)) => {
                n.as_u64().map(|v| vec![v as u32]).unwrap_or_default()
            }
            _ => vec![1, 106, 50],
        }
    }

    /// Rejects configurations that would otherwise panic or mis-shape tensors later:
    /// too few `layer_types` entries, or a KV-head count that is zero or does not divide
    /// `num_attention_heads` for a layer kind that is actually used.
    fn validate(&self) -> Result<(), InferenceError> {
        let layers = self.layer_types.get(..self.num_hidden_layers).ok_or_else(|| {
            InferenceError::InferenceFailed(format!(
                "gemma4 config: layer_types has {} entries but num_hidden_layers is {}",
                self.layer_types.len(),
                self.num_hidden_layers
            ))
        })?;
        for (kind, label, kv_heads) in [
            (LayerKind::Sliding, "sliding", self.num_key_value_heads),
            (LayerKind::Full, "global", self.num_global_key_value_heads),
        ] {
            let used = layers.contains(&kind);
            if used && (kv_heads == 0 || self.num_attention_heads % kv_heads != 0) {
                return Err(InferenceError::InferenceFailed(format!(
                    "gemma4 config: num_attention_heads ({}) is not a multiple of {label} \
                     key/value heads ({kv_heads})",
                    self.num_attention_heads
                )));
            }
        }
        Ok(())
    }
}
/// Attention geometry for one [`LayerKind`], resolved from a [`Gemma4Config`].
struct AttnDims {
    /// Width of each query/key/value head.
    head_dim: usize,
    /// Number of key/value heads.
    num_kv_heads: usize,
    /// RoPE base frequency.
    rope_theta: f32,
    /// Number of head dimensions that receive rotary embedding.
    rope_dim: usize,
    /// Whether the layer has its own value projection.
    has_v_proj: bool,
}
impl Gemma4Config {
    /// Resolves the attention geometry for layers of the given `kind`.
    ///
    /// Sliding layers rotate the full `head_dim` and own a `v_proj`. Full layers use
    /// `global_head_dim` and the global KV-head count, rotate
    /// `round(global_head_dim * partial_rotary_factor)` dimensions, and have no `v_proj`.
    fn attn_dims(&self, kind: LayerKind) -> AttnDims {
        let (head_dim, num_kv_heads, rope_theta, rope_dim, has_v_proj) = match kind {
            LayerKind::Sliding => (
                self.head_dim,
                self.num_key_value_heads,
                self.rope_theta_sliding,
                self.head_dim,
                true,
            ),
            LayerKind::Full => {
                let rotated = ((self.global_head_dim as f32) * self.partial_rotary_factor).round();
                (
                    self.global_head_dim,
                    self.num_global_key_value_heads,
                    self.rope_theta_global,
                    rotated as usize,
                    false,
                )
            }
        };
        AttnDims {
            head_dim,
            num_kv_heads,
            rope_theta,
            rope_dim,
            has_v_proj,
        }
    }
}
/// RMS normalisation followed by an element-wise learned scale: `rms_normalize(x) * weight`.
struct GemmaRmsNorm {
    /// Per-channel scale loaded from `{prefix}.weight`.
    weight: Array,
    /// Epsilon added to the mean square before the reciprocal square root.
    eps: f32,
}
impl GemmaRmsNorm {
    /// Looks up `{prefix}.weight` in `tensors`.
    ///
    /// # Errors
    /// [`InferenceError::InferenceFailed`] when the weight tensor is absent.
    fn load(tensors: &HashMap<String, Array>, prefix: &str, eps: f32) -> Result<Self, InferenceError> {
        let key = format!("{prefix}.weight");
        match tensors.get(&key) {
            Some(weight) => Ok(Self {
                weight: weight.clone(),
                eps,
            }),
            None => Err(InferenceError::InferenceFailed(format!("missing {prefix}.weight"))),
        }
    }

    /// Normalises `x` over its last axis, then multiplies by the learned weight.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let normalized = rms_normalize(x, self.eps)?;
        ops::multiply(&normalized, &self.weight)
    }
}
/// Scales `x` to unit root-mean-square along its last axis, with no learned weight:
/// `x * rsqrt(mean(x * x, axis = -1, keepdims = true) + eps)`.
fn rms_normalize(x: &Array, eps: f32) -> Result<Array, mlx_rs::error::Exception> {
    let squared = ops::multiply(x, x)?;
    let mean_sq = squared.mean_axes(&[-1], true)?;
    let eps = Array::from_f32(eps);
    let inv_rms = ops::rsqrt(&ops::add(&mean_sq, &eps)?)?;
    ops::multiply(x, &inv_rms)
}
/// Gated feed-forward block: `down_proj(gelu_approximate(gate_proj(x)) * up_proj(x))`.
struct GemmaMlp {
    gate_proj: QLinear,
    up_proj: QLinear,
    down_proj: QLinear,
}
impl GemmaMlp {
    /// Loads `{prefix}.gate_proj`, `{prefix}.up_proj` and `{prefix}.down_proj`, in that order.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        let proj = |name: &str| build_qlinear(tensors, &format!("{prefix}.{name}"), quant);
        Ok(Self {
            gate_proj: proj("gate_proj")?,
            up_proj: proj("up_proj")?,
            down_proj: proj("down_proj")?,
        })
    }

    /// Applies the gated MLP to `x`.
    fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let gated = self.gate_proj.forward(x)?;
        let activated = mlx_rs::nn::gelu_approximate(&gated)?;
        let linear = self.up_proj.forward(x)?;
        let hidden = ops::multiply(&activated, &linear)?;
        self.down_proj.forward(&hidden)
    }
}
/// Self-attention for one Gemma 4 decoder layer, owning that layer's KV cache.
struct Gemma4Attention {
    q_proj: QLinear,
    k_proj: QLinear,
    /// `None` when the layer has no value projection; values then reuse `k_proj`'s output.
    v_proj: Option<QLinear>,
    o_proj: QLinear,
    /// Learned RMS norm applied to each query head.
    q_norm: GemmaRmsNorm,
    /// Learned RMS norm applied to each key head.
    k_norm: GemmaRmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    /// True: add `sliding_mask` to scores; false: add the optional global mask.
    is_sliding: bool,
    /// Epsilon for the weight-free RMS normalisation applied to values.
    rms_norm_eps: f32,
    /// RoPE base used when `prop_freqs` is `None`.
    rope_theta: f32,
    /// Explicit RoPE frequencies (one per dimension pair, `head_dim / 2` entries) used for
    /// partially-rotated heads; padding entries are `f32::INFINITY`. `None` = standard RoPE.
    prop_freqs: Option<Array>,
    /// Cached `(keys, values)` after norm and RoPE, shaped
    /// `[batch, num_kv_heads, cached_len, head_dim]` and grown along axis 2.
    kv_cache: Option<(Array, Array)>,
}
impl Gemma4Attention {
    /// Loads projections and norms under `prefix` for a layer of the given `kind`.
    ///
    /// Head geometry and RoPE settings come from `cfg.attn_dims(kind)`; `v_proj` is only
    /// loaded when that geometry says the layer has one.
    ///
    /// # Errors
    /// Propagates errors from `build_qlinear` and `GemmaRmsNorm::load` (e.g. a missing tensor).
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        cfg: &Gemma4Config,
        kind: LayerKind,
    ) -> Result<Self, InferenceError> {
        let d = cfg.attn_dims(kind);
        let quant = Some(&cfg.quant);
        let v_proj = if d.has_v_proj {
            Some(build_qlinear(tensors, &format!("{prefix}.v_proj"), quant)?)
        } else {
            None
        };
        // Partial rotation: keep the rotary width at `head_dim` but supply explicit
        // frequencies. Entries for even i < rope_dim are `rope_theta^(i / head_dim)`; the
        // remainder up to `head_dim / 2` is padded with infinity.
        // NOTE(review): infinity presumably makes MLX leave those pairs unrotated — confirm.
        let prop_freqs = if d.rope_dim < d.head_dim {
            let half = d.head_dim / 2;
            let mut fr: Vec<f32> = Vec::with_capacity(half);
            let mut i = 0;
            while i < d.rope_dim {
                fr.push(d.rope_theta.powf(i as f32 / d.head_dim as f32));
                i += 2;
            }
            while fr.len() < half {
                fr.push(f32::INFINITY);
            }
            Some(Array::from_slice(&fr, &[half as i32]))
        } else {
            None
        };
        Ok(Self {
            q_proj: build_qlinear(tensors, &format!("{prefix}.q_proj"), quant)?,
            k_proj: build_qlinear(tensors, &format!("{prefix}.k_proj"), quant)?,
            v_proj,
            o_proj: build_qlinear(tensors, &format!("{prefix}.o_proj"), quant)?,
            q_norm: GemmaRmsNorm::load(tensors, &format!("{prefix}.q_norm"), cfg.rms_norm_eps)?,
            k_norm: GemmaRmsNorm::load(tensors, &format!("{prefix}.k_norm"), cfg.rms_norm_eps)?,
            num_heads: cfg.num_attention_heads,
            num_kv_heads: d.num_kv_heads,
            head_dim: d.head_dim,
            is_sliding: kind == LayerKind::Sliding,
            rms_norm_eps: cfg.rms_norm_eps,
            rope_theta: d.rope_theta,
            prop_freqs,
            kv_cache: None,
        })
    }

    /// Drops the KV cache.
    fn clear_cache(&mut self) {
        self.kv_cache = None;
    }

    /// Keeps only the first `len` cached positions (axis 2); no-op when the cache is absent
    /// or already no longer than `len`.
    fn truncate_cache(&mut self, len: i32) {
        if let Some((k, v)) = self.kv_cache.take() {
            if k.shape()[2] > len {
                let k = k.index((.., .., ..len, ..));
                let v = v.index((.., .., ..len, ..));
                self.kv_cache = Some((k, v));
            } else {
                self.kv_cache = Some((k, v));
            }
        }
    }

    /// Attends `x` (indexed as `[batch, seq_len, ...]`) at absolute positions starting at
    /// `offset`, appending the new keys/values to the cache.
    ///
    /// On sliding layers `sliding_mask` is added to the scores; on full layers `global_mask`
    /// is added when `Some`, otherwise scores are left unmasked. Both masks are additive and
    /// must broadcast to `[batch, num_heads, seq_len, cached_len + seq_len]`.
    fn forward(
        &mut self,
        x: &Array,
        offset: i32,
        sliding_mask: &Array,
        global_mask: Option<&Array>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let s = x.shape();
        let (batch, seq_len) = (s[0], s[1]);
        let hd = self.head_dim as i32;
        let q = self.q_proj.forward(x)?;
        let k_raw = self.k_proj.forward(x)?;
        // Layers without `v_proj` derive values from the raw key projection.
        let v_raw = match &mut self.v_proj {
            Some(vp) => vp.forward(x)?,
            None => k_raw.clone(),
        };
        // [batch, seq, heads * hd] -> [batch, heads, seq, hd]
        let to_heads = |t: &Array, heads: usize| -> Result<Array, mlx_rs::error::Exception> {
            let r = ops::reshape(t, &[batch, seq_len, heads as i32, hd])?;
            ops::transpose_axes(&r, &[0, 2, 1, 3])
        };
        let q = to_heads(&q, self.num_heads)?;
        let k = to_heads(&k_raw, self.num_kv_heads)?;
        let v = to_heads(&v_raw, self.num_kv_heads)?;
        // Queries/keys get learned per-head norms; values get a weight-free RMS norm.
        let q = self.q_norm.forward(&q)?;
        let k = self.k_norm.forward(&k)?;
        let v = rms_normalize(&v, self.rms_norm_eps)?;
        let q = self.apply_rope(&q, offset)?;
        let k = self.apply_rope(&k, offset)?;
        // Append to the cache along the sequence axis and store the result.
        let (k, v) = match self.kv_cache.take() {
            Some((ck, cv)) => (
                ops::concatenate_axis(&[&ck, &k], 2)?,
                ops::concatenate_axis(&[&cv, &v], 2)?,
            ),
            None => (k, v),
        };
        self.kv_cache = Some((k.clone(), v.clone()));
        // Grouped-query attention: repeat KV heads to match the query head count.
        let n_rep = (self.num_heads / self.num_kv_heads) as i32;
        let k = repeat_heads(&k, n_rep)?;
        let v = repeat_heads(&v, n_rep)?;
        // NOTE(review): scores are not scaled by 1/sqrt(head_dim); presumably intentional
        // given the q/k norms — the ignored tests compare against mlx_lm reference argmaxes.
        let scores = ops::matmul(&q, &ops::transpose_axes(&k, &[0, 1, 3, 2])?)?;
        let scores = if self.is_sliding {
            ops::add(&scores, sliding_mask)?
        } else if let Some(mask) = global_mask {
            ops::add(&scores, mask)?
        } else {
            scores
        };
        let attn = ops::softmax_axis(&scores, -1, None)?;
        let out = ops::matmul(&attn, &v)?;
        // [batch, heads, seq, hd] -> [batch, seq, heads * hd]
        let out = ops::transpose_axes(&out, &[0, 2, 1, 3])?;
        let merged = ops::reshape(&out, &[batch, seq_len, (self.num_heads * self.head_dim) as i32])?;
        self.o_proj.forward(&merged)
    }

    /// Applies rotary embedding over `head_dim` dimensions (`traditional = false`,
    /// scale 1.0) starting at position `offset`: base `rope_theta` when no explicit
    /// frequencies are configured, otherwise `prop_freqs` with no base.
    fn apply_rope(&self, x: &Array, offset: i32) -> Result<Array, mlx_rs::error::Exception> {
        match &self.prop_freqs {
            None => mlx_rs::fast::rope(
                x,
                self.head_dim as i32,
                false,
                self.rope_theta,
                1.0,
                offset,
                None::<&Array>,
            ),
            Some(freqs) => mlx_rs::fast::rope(
                x,
                self.head_dim as i32,
                false,
                None::<f32>,
                1.0,
                offset,
                Some(freqs),
            ),
        }
    }
}
/// One Gemma 4 decoder layer. Attention and MLP sub-blocks are each wrapped by a norm on
/// their input and a norm on their output before the residual add; an optional
/// `layer_scalar` then scales the whole layer output.
struct Gemma4Block {
    input_layernorm: GemmaRmsNorm,
    attn: Gemma4Attention,
    post_attention_layernorm: GemmaRmsNorm,
    pre_feedforward_layernorm: GemmaRmsNorm,
    mlp: GemmaMlp,
    post_feedforward_layernorm: GemmaRmsNorm,
    /// Multiplies the block output when `{prefix}.layer_scalar` exists in the checkpoint.
    layer_scalar: Option<Array>,
}
impl Gemma4Block {
    /// Loads every sub-module of the layer rooted at `prefix`.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        cfg: &Gemma4Config,
        kind: LayerKind,
    ) -> Result<Self, InferenceError> {
        let eps = cfg.rms_norm_eps;
        let norm = |name: &str| GemmaRmsNorm::load(tensors, &format!("{prefix}.{name}"), eps);
        Ok(Self {
            input_layernorm: norm("input_layernorm")?,
            attn: Gemma4Attention::load(tensors, &format!("{prefix}.self_attn"), cfg, kind)?,
            post_attention_layernorm: norm("post_attention_layernorm")?,
            pre_feedforward_layernorm: norm("pre_feedforward_layernorm")?,
            mlp: GemmaMlp::load(tensors, &format!("{prefix}.mlp"), Some(&cfg.quant))?,
            post_feedforward_layernorm: norm("post_feedforward_layernorm")?,
            layer_scalar: tensors.get(&format!("{prefix}.layer_scalar")).cloned(),
        })
    }

    /// Drops this layer's KV cache.
    fn clear_cache(&mut self) {
        self.attn.clear_cache();
    }

    /// Truncates this layer's KV cache to `len` positions.
    fn truncate_cache(&mut self, len: i32) {
        self.attn.truncate_cache(len);
    }

    /// Runs the layer on `x`; masks and `offset` are forwarded to attention unchanged.
    fn forward(
        &mut self,
        x: &Array,
        offset: i32,
        sliding_mask: &Array,
        global_mask: Option<&Array>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        // Attention sub-block.
        let attn_in = self.input_layernorm.forward(x)?;
        let attn_out = self.attn.forward(&attn_in, offset, sliding_mask, global_mask)?;
        let attn_out = self.post_attention_layernorm.forward(&attn_out)?;
        let hidden = ops::add(x, &attn_out)?;

        // Feed-forward sub-block.
        let mlp_in = self.pre_feedforward_layernorm.forward(&hidden)?;
        let mlp_out = self.mlp.forward(&mlp_in)?;
        let mlp_out = self.post_feedforward_layernorm.forward(&mlp_out)?;
        let out = ops::add(&hidden, &mlp_out)?;

        if let Some(scalar) = &self.layer_scalar {
            ops::multiply(&out, scalar)
        } else {
            Ok(out)
        }
    }
}
/// Gemma 4 text decoder: scaled token embedding, a stack of [`Gemma4Block`]s, a final
/// RMS norm, and an LM head built from the `embed_tokens` tensors.
pub struct Gemma4Model {
    config: Gemma4Config,
    embed_tokens: QEmbedding,
    /// Built from the `embed_tokens` tensors, i.e. tied to the input embedding.
    lm_head: QLinear,
    layers: Vec<Gemma4Block>,
    final_norm: GemmaRmsNorm,
    /// `sqrt(hidden_size)`; embeddings are multiplied by this before the first layer.
    embed_scale: f32,
    /// Tokens whose keys/values are held in every layer's KV cache, in position order.
    cached_tokens: Vec<u32>,
}
impl Gemma4Model {
    /// Builds the model from tensors under the `language_model.model` prefix.
    ///
    /// # Errors
    /// Fails if `config.layer_types` has fewer entries than `num_hidden_layers`, or if any
    /// sub-module fails to load (e.g. a missing tensor).
    fn from_tensors(
        tensors: &HashMap<String, Array>,
        config: Gemma4Config,
    ) -> Result<Self, InferenceError> {
        let pfx = "language_model.model";
        let embed_tokens =
            build_qembedding(tensors, &format!("{pfx}.embed_tokens"), Some(&config.quant))?;
        // Tied weights: the LM head reuses the embedding tensors.
        let lm_head = build_qlinear(tensors, &format!("{pfx}.embed_tokens"), Some(&config.quant))?;
        let final_norm = GemmaRmsNorm::load(tensors, &format!("{pfx}.norm"), config.rms_norm_eps)?;
        // Checked explicitly: indexing a short `layer_types` list would panic.
        let kinds = config.layer_types.get(..config.num_hidden_layers).ok_or_else(|| {
            InferenceError::InferenceFailed(format!(
                "gemma4: layer_types has {} entries, expected {}",
                config.layer_types.len(),
                config.num_hidden_layers
            ))
        })?;
        let layers = kinds
            .iter()
            .enumerate()
            .map(|(i, &kind)| {
                Gemma4Block::load(tensors, &format!("{pfx}.layers.{i}"), &config, kind)
            })
            .collect::<Result<Vec<_>, _>>()?;
        let embed_scale = (config.hidden_size as f32).sqrt();
        Ok(Self {
            config,
            embed_tokens,
            lm_head,
            layers,
            final_norm,
            embed_scale,
            cached_tokens: Vec::new(),
        })
    }

    /// Empties every layer's KV cache and forgets the cached token list.
    fn clear_cache(&mut self) {
        for layer in &mut self.layers {
            layer.clear_cache();
        }
        self.cached_tokens.clear();
    }

    /// Prepares the KV cache for `tokens` and returns how many leading tokens are reusable.
    ///
    /// The reusable prefix is the longest common prefix with the cached tokens, capped at
    /// `tokens.len() - 1` so at least one token is always left to run (its logits are needed).
    /// Caches longer than that prefix are truncated. Callers then run
    /// `forward(&tokens[n..], n)` with the returned `n`.
    fn begin_prompt(&mut self, tokens: &[u32]) -> usize {
        let cap = tokens.len().saturating_sub(1);
        let mut common = 0;
        while common < self.cached_tokens.len()
            && common < cap
            && self.cached_tokens[common] == tokens[common]
        {
            common += 1;
        }
        if common < self.cached_tokens.len() {
            for layer in &mut self.layers {
                layer.truncate_cache(common as i32);
            }
            self.cached_tokens.truncate(common);
        }
        common
    }

    /// Runs `tokens` at absolute positions starting at `offset` and returns f32 logits for
    /// the last token.
    ///
    /// `offset` is expected to equal the number of tokens already in the KV cache: masks
    /// are built for `offset + tokens.len()` keys. On success the tokens are recorded as
    /// cached. On failure the layer caches may be partially updated, so everything is
    /// cleared to keep `cached_tokens` and the caches consistent.
    ///
    /// # Errors
    /// Fails on empty `tokens` or on any MLX error.
    fn forward(&mut self, tokens: &[u32], offset: i32) -> Result<Vec<f32>, InferenceError> {
        if tokens.is_empty() {
            return Err(InferenceError::InferenceFailed(
                "gemma4 forward called with no tokens".into(),
            ));
        }
        match self.forward_logits(tokens, offset) {
            Ok(logits) => {
                self.cached_tokens.extend_from_slice(tokens);
                Ok(logits)
            }
            Err(e) => {
                self.clear_cache();
                Err(e)
            }
        }
    }

    /// The forward computation proper; does not touch `cached_tokens`.
    fn forward_logits(&mut self, tokens: &[u32], offset: i32) -> Result<Vec<f32>, InferenceError> {
        let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
        let ids: Vec<i32> = tokens.iter().map(|&t| t as i32).collect();
        let seq_len = ids.len() as i32;
        let id_arr = Array::from_slice(&ids, &[1, seq_len]);
        let mut h = self.embed_tokens.forward(&id_arr).map_err(map_err)?;
        h = ops::multiply(&h, &Array::from_f32(self.embed_scale)).map_err(map_err)?;
        // Debug hook: when CAR_GEMMA4_DUMP names a directory, intermediate activations are
        // written there.
        let dump = std::env::var("CAR_GEMMA4_DUMP").ok();
        if let Some(d) = &dump {
            dump_tensor(d, "embed", &h);
        }
        let kv_len = offset + seq_len;
        let dtype = h.dtype();
        let sliding_mask =
            build_windowed_causal_mask(seq_len, kv_len, self.config.sliding_window as i32, dtype)
                .map_err(map_err)?;
        // A single query token may attend to every cached key, so full layers need no mask then.
        let global_mask = if seq_len > 1 {
            Some(build_windowed_causal_mask(seq_len, kv_len, i32::MAX, dtype).map_err(map_err)?)
        } else {
            None
        };
        for (i, layer) in self.layers.iter_mut().enumerate() {
            h = layer
                .forward(&h, offset, &sliding_mask, global_mask.as_ref())
                .map_err(map_err)?;
            if let Some(d) = &dump {
                if i == 0 || i == 5 {
                    dump_tensor(d, &format!("layer{i}"), &h);
                }
            }
        }
        h = self.final_norm.forward(&h).map_err(map_err)?;
        // Only the last position's logits are needed for next-token prediction.
        let last = h.index((.., seq_len - 1, ..));
        let mut logits = self.lm_head.forward(&last).map_err(map_err)?;
        if let Some(cap) = self.config.final_logit_softcapping {
            let cap_a = Array::from_f32(cap);
            let scaled = ops::divide(&logits, &cap_a).map_err(map_err)?;
            logits = ops::multiply(&ops::tanh(&scaled).map_err(map_err)?, &cap_a).map_err(map_err)?;
        }
        let logits = logits.as_dtype(mlx_rs::Dtype::Float32).map_err(map_err)?;
        mlx_rs::transforms::eval([&logits]).map_err(map_err)?;
        Ok(logits.as_slice().to_vec())
    }

    /// Clears the cache and prefills `tokens` from position 0.
    fn prefill_logits(&mut self, tokens: &[u32]) -> Result<Vec<f32>, InferenceError> {
        self.clear_cache();
        self.forward(tokens, 0)
    }
}
/// Text-only inference backend for Gemma 4 checkpoints: model, tokenizer and chat template.
pub struct Gemma4Backend {
    model: Gemma4Model,
    tokenizer: Tokenizer,
    /// Context length reported to callers (set once at load time).
    context_length: usize,
    /// Result of `ChatTemplate::load` on the model directory; may be `None`.
    chat_template: Option<ChatTemplate>,
    /// Token ids that terminate generation, copied from the parsed config.
    eos_token_ids: Vec<u32>,
}
// SAFETY: NOTE(review) — these impls assert that the MLX arrays held inside `model` may be
// moved to and shared across threads; nothing in this file synchronises access to them.
// Every method that touches `model` takes `&mut self`, so shared (`&self`) access only
// reaches the tokenizer, context length, EOS ids and chat template. Confirm MLX's
// thread-safety guarantees for `Array` before relying on these impls.
unsafe impl Send for Gemma4Backend {}
unsafe impl Sync for Gemma4Backend {}
impl Gemma4Backend {
    /// Context length reported by [`Gemma4Backend::context_length`]; not read from config.
    const CONTEXT_LENGTH: usize = 131_072;

    /// Loads config, tokenizer, text weights and chat template from `model_dir`.
    ///
    /// Shards are taken from `model.safetensors.index.json` when present, otherwise
    /// `model.safetensors`. Tensors of the audio/vision towers are dropped.
    ///
    /// # Errors
    /// Any I/O, parse, tokenizer, template or missing-tensor failure.
    pub fn load(model_dir: &Path) -> Result<Self, InferenceError> {
        info!(dir = %model_dir.display(), "loading Gemma 4 unified text backend");
        let config = Gemma4Config::parse(model_dir)?;
        let tokenizer = Tokenizer::from_file(model_dir.join("tokenizer.json"))
            .map_err(|e| InferenceError::InferenceFailed(format!("load tokenizer: {e}")))?;
        let shards = Self::shard_files(model_dir)?;
        let (tensors, dropped) = Self::load_text_tensors(model_dir, &shards)?;
        info!(tensors = tensors.len(), dropped, "Gemma 4 weights loaded (multimodal towers dropped)");
        let eos_token_ids = config.eos_token_ids.clone();
        let chat_template = ChatTemplate::load(model_dir)?;
        let model = Gemma4Model::from_tensors(&tensors, config)?;
        Ok(Self {
            model,
            tokenizer,
            context_length: Self::CONTEXT_LENGTH,
            chat_template,
            eos_token_ids,
        })
    }

    /// Lists the safetensors shard file names (deduplicated, sorted) for `model_dir`.
    fn shard_files(model_dir: &Path) -> Result<Vec<String>, InferenceError> {
        let index_path = model_dir.join("model.safetensors.index.json");
        if !index_path.exists() {
            return Ok(vec!["model.safetensors".to_string()]);
        }
        let raw = std::fs::read_to_string(&index_path)
            .map_err(|e| InferenceError::InferenceFailed(format!("read index: {e}")))?;
        let idx: serde_json::Value = serde_json::from_str(&raw)
            .map_err(|e| InferenceError::InferenceFailed(format!("parse index: {e}")))?;
        let unique: std::collections::BTreeSet<String> = idx
            .get("weight_map")
            .and_then(|m| m.as_object())
            .into_iter()
            .flat_map(|map| map.values())
            .filter_map(|v| v.as_str().map(|s| s.to_string()))
            .collect();
        // An empty index would otherwise surface later as a confusing missing-tensor error.
        if unique.is_empty() {
            return Err(InferenceError::InferenceFailed(
                "model.safetensors.index.json lists no weight shards".into(),
            ));
        }
        Ok(unique.into_iter().collect())
    }

    /// Loads every shard, skipping multimodal tower tensors.
    /// Returns the kept tensors and the number of skipped ones.
    fn load_text_tensors(
        model_dir: &Path,
        shards: &[String],
    ) -> Result<(HashMap<String, Array>, usize), InferenceError> {
        const SKIPPED_PREFIXES: [&str; 3] = ["embed_audio", "embed_vision", "vision_embedder"];
        let mut tensors: HashMap<String, Array> = HashMap::new();
        let mut dropped = 0usize;
        for shard in shards {
            let path = model_dir.join(shard);
            let loaded = Array::load_safetensors(&path)
                .map_err(|e| InferenceError::InferenceFailed(format!("load {shard}: {e}")))?;
            for (k, v) in loaded {
                if SKIPPED_PREFIXES.iter().any(|p| k.starts_with(*p)) {
                    dropped += 1;
                    continue;
                }
                tensors.insert(k, v);
            }
        }
        Ok((tensors, dropped))
    }

    /// Maximum context length this backend reports.
    pub fn context_length(&self) -> usize {
        self.context_length
    }

    /// Token ids that end generation.
    pub fn eos_token_ids(&self) -> Vec<u32> {
        self.eos_token_ids.clone()
    }

    /// The chat template, if one was loaded.
    pub fn chat_template(&self) -> Option<&ChatTemplate> {
        self.chat_template.as_ref()
    }

    /// Tokenizes `text`, adding special tokens.
    pub fn encode(&self, text: &str) -> Result<Vec<u32>, InferenceError> {
        self.tokenizer
            .encode(text, true)
            .map(|e| e.get_ids().to_vec())
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))
    }

    /// Detokenizes `tokens`, skipping special tokens.
    pub fn decode(&self, tokens: &[u32]) -> Result<String, InferenceError> {
        self.tokenizer
            .decode(tokens, true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))
    }

    /// Runs `tokens` starting at absolute position `pos` and returns last-token logits.
    ///
    /// # Errors
    /// Fails if `pos` does not fit in an `i32` (instead of silently truncating), or if the
    /// model forward pass fails.
    pub fn forward(&mut self, tokens: &[u32], pos: usize) -> Result<Vec<f32>, InferenceError> {
        if pos > i32::MAX as usize {
            return Err(InferenceError::InferenceFailed(format!(
                "position {pos} exceeds the supported maximum {}",
                i32::MAX
            )));
        }
        self.model.forward(tokens, pos as i32)
    }

    /// Discards all cached keys/values.
    pub fn clear_kv_cache(&mut self) {
        self.model.clear_cache();
    }

    /// Reuses the cached prefix shared with `tokens`; returns the position to resume from.
    pub fn begin_prompt(&mut self, tokens: &[u32]) -> usize {
        self.model.begin_prompt(tokens)
    }

    /// Clears the cache and prefills `tokens` from position 0.
    pub fn prefill_logits(&mut self, tokens: &[u32]) -> Result<Vec<f32>, InferenceError> {
        self.model.prefill_logits(tokens)
    }
}
/// Delimiter Gemma 4 places on both sides of string literals inside tool-call arguments.
const ESC: &str = "<|\"|>";

/// Extracts `call:NAME{...}` tool calls from model output.
///
/// Returns the input with every recognised call removed (then trimmed), plus the calls in
/// order of appearance. `NAME` is a run of alphanumerics, `_` or `-`; arguments are parsed
/// by [`parse_object`]. An occurrence of `call:` not followed by `NAME{` is left in the text.
/// An unterminated argument object extends to the end of the input.
pub(crate) fn parse_gemma4_tool_calls(text: &str) -> (String, Vec<ToolCall>) {
    const MARKER: &str = "call:";
    let mut calls = Vec::new();
    let mut clean = String::new();
    let mut cursor = 0; // start of the text not yet copied into `clean`
    let mut from = 0; // where the next search for MARKER begins
    while let Some(rel) = text[from..].find(MARKER) {
        let cs = from + rel;
        let after = cs + MARKER.len();
        let name_end = text[after..]
            .find(|c: char| !(c.is_alphanumeric() || c == '_' || c == '-'))
            .map_or(text.len(), |i| after + i);
        if name_end > after && text[name_end..].starts_with('{') {
            let close = find_matching_brace(text, name_end);
            let args = parse_object(&text[name_end + 1..close.min(text.len())]);
            calls.push(ToolCall {
                id: None,
                name: text[after..name_end].to_string(),
                arguments: args,
            });
            clean.push_str(&text[cursor..cs]);
            cursor = (close + 1).min(text.len());
            from = cursor;
        } else {
            // `name_end` is a char boundary no earlier than `after` (> `from`), so resuming
            // there guarantees progress. The previous `after + 1` could step past the end
            // (text ending in "call:") or into a multi-byte char, and then panic on slicing.
            from = name_end;
        }
    }
    clean.push_str(&text[cursor..]);
    (clean.trim().to_string(), calls)
}
/// Returns the byte index of the bracket that closes the bracket at `start`.
///
/// `{` and `[` both increase the depth and `}` and `]` both decrease it (bracket kinds are
/// not distinguished); the index where depth returns to zero is returned. Spans wrapped in
/// [`ESC`] are skipped as opaque string literals. Returns `text.len()` if no matching bracket
/// exists or an [`ESC`] literal is unterminated.
fn find_matching_brace(text: &str, start: usize) -> usize {
    let mut depth = 0i32;
    let mut pos = start;
    while pos < text.len() {
        let rest = &text[pos..];
        if let Some(literal) = rest.strip_prefix(ESC) {
            let Some(close) = literal.find(ESC) else {
                return text.len();
            };
            pos += 2 * ESC.len() + close;
            continue;
        }
        let ch = rest.chars().next().unwrap();
        if ch == '{' || ch == '[' {
            depth += 1;
        } else if ch == '}' || ch == ']' {
            depth -= 1;
            if depth == 0 {
                return pos;
            }
        }
        pos += ch.len_utf8();
    }
    text.len()
}
/// Splits `text` on `delim` wherever bracket depth is zero and the delimiter is not inside
/// an [`ESC`]-wrapped literal.
///
/// Empty intermediate pieces are kept; a trailing piece is kept only if it contains
/// non-whitespace. Bracket characters always count as brackets, even when they equal `delim`.
fn split_top_level(text: &str, delim: char) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut nesting = 0i32;
    let mut piece_start = 0;
    let mut pos = 0;
    while pos < text.len() {
        let rest = &text[pos..];
        if let Some(literal) = rest.strip_prefix(ESC) {
            // Jump past the closing delimiter, or to the end if the literal is unterminated.
            pos = match literal.find(ESC) {
                Some(end) => pos + 2 * ESC.len() + end,
                None => text.len(),
            };
            continue;
        }
        let ch = rest.chars().next().unwrap();
        if ch == '{' || ch == '[' {
            nesting += 1;
        } else if ch == '}' || ch == ']' {
            nesting -= 1;
        } else if ch == delim && nesting == 0 {
            pieces.push(&text[piece_start..pos]);
            piece_start = pos + ch.len_utf8();
        }
        pos += ch.len_utf8();
    }
    let tail = &text[piece_start..];
    if !tail.trim().is_empty() {
        pieces.push(tail);
    }
    pieces
}
/// Returns the byte offset of the first `:` outside any [`ESC`]-wrapped literal.
///
/// Bracket nesting is not tracked, so in a `key:value` entry whose key has no colon this
/// is the key/value separator. Returns `None` if there is no such colon, or if an [`ESC`]
/// literal is unterminated before one is found.
fn find_top_level_colon(text: &str) -> Option<usize> {
    let mut pos = 0;
    while pos < text.len() {
        let rest = &text[pos..];
        if let Some(literal) = rest.strip_prefix(ESC) {
            pos += 2 * ESC.len() + literal.find(ESC)?;
            continue;
        }
        let ch = rest.chars().next()?;
        if ch == ':' {
            return Some(pos);
        }
        pos += ch.len_utf8();
    }
    None
}
/// Parses the body of a Gemma 4 argument object (`k1:v1,k2:v2,...`, no outer braces).
///
/// Keys are trimmed and stripped of surrounding `"`; values go through [`parse_value`].
/// Blank entries and entries without a colon are skipped; a repeated key keeps its last value.
fn parse_object(text: &str) -> std::collections::HashMap<String, serde_json::Value> {
    split_top_level(text, ',')
        .into_iter()
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .filter_map(|entry| {
            let colon = find_top_level_colon(entry)?;
            let key = entry[..colon].trim().trim_matches('"').to_string();
            Some((key, parse_value(&entry[colon + 1..])))
        })
        .collect()
}
fn parse_value(text: &str) -> serde_json::Value {
let text = text.trim();
if let Some(rest) = text.strip_prefix(ESC) {
let inner = match rest.find(ESC) {
Some(e) => &rest[..e],
None => rest,
};
return serde_json::Value::String(inner.to_string());
}
if text.starts_with('{') {
let close = find_matching_brace(text, 0);
let obj = parse_object(&text[1..close.min(text.len())]);
return serde_json::Value::Object(obj.into_iter().collect());
}
if text.starts_with('[') {
let close = find_matching_brace(text, 0);
let items: Vec<serde_json::Value> = split_top_level(&text[1..close.min(text.len())], ',')
.into_iter()
.filter(|s| !s.trim().is_empty())
.map(parse_value)
.collect();
return serde_json::Value::Array(items);
}
serde_json::from_str(text).unwrap_or_else(|_| serde_json::Value::String(text.to_string()))
}
/// Expands grouped key/value heads for grouped-query attention:
/// `[b, h_kv, t, d]` → `[b, h_kv * n_rep, t, d]`, each KV head repeated `n_rep` times
/// consecutively. With `n_rep == 1` the input is returned unchanged (cloned).
fn repeat_heads(x: &Array, n_rep: i32) -> Result<Array, mlx_rs::error::Exception> {
    if n_rep == 1 {
        return Ok(x.clone());
    }
    let shape = x.shape();
    let [batch, kv_heads, seq, dim] = [shape[0], shape[1], shape[2], shape[3]];
    // Insert a singleton axis after the head axis, tile along it, then fold it back in.
    let expanded = ops::reshape(x, &[batch, kv_heads, 1, seq, dim])?;
    let repeated = ops::tile(&expanded, &[1, 1, n_rep, 1, 1])?;
    ops::reshape(&repeated, &[batch, kv_heads * n_rep, seq, dim])
}
/// Builds an additive attention mask of shape `[1, 1, l_q, l_k]` in `dtype`.
/// The values come from [`windowed_causal_mask_values`].
fn build_windowed_causal_mask(
    l_q: i32,
    l_k: i32,
    window: i32,
    dtype: mlx_rs::Dtype,
) -> Result<Array, mlx_rs::error::Exception> {
    let additive = windowed_causal_mask_values(l_q, l_k, window);
    let mask = Array::from_slice(&additive, &[1, 1, l_q, l_k]);
    mask.as_dtype(dtype)
}
/// Row-major `l_q × l_k` additive mask values: `0.0` where attention is allowed and
/// `-inf` elsewhere.
///
/// Query row `i` sits at absolute position `l_k - l_q + i` (the queries are the last
/// `l_q` keys). It may see key `j` iff `pos - window < j <= pos`, i.e. itself and the
/// previous `window - 1` keys. Pass `i32::MAX` for a plain causal mask.
fn windowed_causal_mask_values(l_q: i32, l_k: i32, window: i32) -> Vec<f32> {
    let offset = (l_k - l_q) as i64;
    let span = window as i64;
    let mut values = Vec::new();
    for row in 0..l_q {
        let query_pos = offset + row as i64;
        let oldest_visible = query_pos - span + 1;
        for col in 0..l_k {
            let visible = (oldest_visible..=query_pos).contains(&(col as i64));
            values.push(if visible { 0.0 } else { f32::NEG_INFINITY });
        }
    }
    values
}
/// Debug helper: writes `t`, converted to f32, as raw little-endian bytes to
/// `{dir}/{name}.bin`.
///
/// Best effort by design: conversion, evaluation, directory-creation and write failures are
/// all ignored.
fn dump_tensor(dir: &str, name: &str, t: &Array) {
    let t_f32 = match t.as_dtype(mlx_rs::Dtype::Float32) {
        Ok(converted) => converted,
        Err(_) => return,
    };
    let _ = mlx_rs::transforms::eval([&t_f32]);
    let floats: &[f32] = t_f32.as_slice();
    let mut bytes: Vec<u8> = Vec::with_capacity(floats.len() * 4);
    for value in floats {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    let _ = std::fs::create_dir_all(dir);
    let _ = std::fs::write(format!("{dir}/{name}.bin"), &bytes);
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// First cached snapshot of `mlx-community/gemma-4-12B-it-4bit` under
    /// `$HOME/.cache/huggingface/hub` that contains a `config.json`.
    fn gemma4_dir() -> Option<PathBuf> {
        let snaps = PathBuf::from(std::env::var("HOME").ok()?).join(
            ".cache/huggingface/hub/models--mlx-community--gemma-4-12B-it-4bit/snapshots",
        );
        std::fs::read_dir(snaps)
            .ok()?
            .flatten()
            .map(|e| e.path())
            .find(|p| p.join("config.json").exists())
    }

    /// Index of the largest logit. `total_cmp` gives a total order, so a NaN logit cannot
    /// panic the comparison.
    fn argmax(logits: &[f32]) -> usize {
        logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i)
            .expect("logits must not be empty")
    }

    /// Prefills `prompt` into a fresh KV cache, then greedily decodes up to `max_new` tokens.
    /// Returns the generated tokens (without the EOS) and whether an EOS token was produced.
    fn greedy_generate(b: &mut Gemma4Backend, prompt: &[u32], max_new: usize) -> (Vec<u32>, bool) {
        b.clear_kv_cache();
        let mut logits = b.forward(prompt, 0).expect("prefill");
        let eos = b.eos_token_ids();
        let mut out = Vec::new();
        for step in 0..max_new {
            let next = argmax(&logits) as u32;
            if eos.contains(&next) {
                return (out, true);
            }
            out.push(next);
            logits = b.forward(&[next], prompt.len() + step).expect("decode");
        }
        (out, false)
    }

    #[test]
    #[ignore = "requires mlx-community/gemma-4-12B-it-4bit weights in the HF cache"]
    fn chat_generates_coherent_text() {
        use crate::tasks::generate::{Message, ThinkingMode};
        let Some(dir) = gemma4_dir() else {
            eprintln!("SKIP: weights not present");
            return;
        };
        let mut b = Gemma4Backend::load(&dir).expect("load");
        let prompt = b
            .chat_template()
            .expect("template")
            .render(
                &[Message::User {
                    content: "What is the capital of France? Answer in one word.".into(),
                }],
                None,
                ThinkingMode::Off,
            )
            .expect("render");
        let toks = b.encode(&prompt).expect("encode");
        let (out, _) = greedy_generate(&mut b, &toks, 16);
        let text = b.decode(&out).unwrap_or_default();
        eprintln!("CHAT OUT: {text:?}");
        assert!(
            text.to_lowercase().contains("paris"),
            "expected a coherent answer mentioning Paris, got {text:?}"
        );
    }

    #[test]
    #[ignore = "requires mlx-community/gemma-4-12B-it-4bit weights in the HF cache"]
    fn tool_use_generation_bounds() {
        use crate::tasks::generate::{Message, ThinkingMode};
        let Some(dir) = gemma4_dir() else {
            return;
        };
        let mut b = Gemma4Backend::load(&dir).expect("load");
        let tool = serde_json::json!([{
            "type": "function",
            "function": {
                "name": "file_write",
                "description": "Write content to a file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "file path"},
                        "content": {"type": "string", "description": "file content"}
                    },
                    "required": ["path", "content"]
                }
            }
        }]);
        let prompt = b
            .chat_template()
            .expect("template")
            .render(
                &[Message::User {
                    content: "Create a file named hello.txt containing the text hi. Use the file_write tool.".into(),
                }],
                tool.as_array().map(|a| a.as_slice()),
                ThinkingMode::Off,
            )
            .expect("render");
        let toks = b.encode(&prompt).expect("encode");
        let start = std::time::Instant::now();
        let (out, stopped) = greedy_generate(&mut b, &toks, 512);
        let elapsed = start.elapsed();
        let toks_per_s = out.len() as f64 / elapsed.as_secs_f64();
        eprintln!(
            "GEN {} tokens in {:.1}s ({:.1} tok/s); stopped_at_eos={}; OUT={:?}",
            out.len(),
            elapsed.as_secs_f64(),
            toks_per_s,
            stopped,
            b.decode(&out).unwrap_or_default()
        );
    }

    #[test]
    #[ignore = "requires mlx-community/gemma-4-12B-it-4bit weights in the HF cache"]
    fn concludes_after_tool_result() {
        use crate::tasks::generate::{Message, ThinkingMode, ToolCall};
        use std::collections::HashMap;
        let Some(dir) = gemma4_dir() else {
            return;
        };
        let mut b = Gemma4Backend::load(&dir).expect("load");
        let tool = serde_json::json!([
            {
                "type": "function",
                "function": {
                    "name": "write_file",
                    "description": "Write content to a file",
                    "parameters": {
                        "type": "object",
                        "properties": {"path": {"type": "string"}, "content": {"type": "string"}},
                        "required": ["path", "content"]
                    }
                }
            },
            {
                "type": "function",
                "function": {
                    "name": "read_file",
                    "description": "Read the content of a file",
                    "parameters": {
                        "type": "object",
                        "properties": {"path": {"type": "string"}},
                        "required": ["path"]
                    }
                }
            }
        ]);
        let mut args = HashMap::new();
        args.insert("path".to_string(), serde_json::json!("output.txt"));
        args.insert("content".to_string(), serde_json::json!("hello world"));
        let msgs = vec![
            Message::System {
                content: "You are a helpful AI assistant. Complete the task using the tools available to you.".into(),
            },
            Message::User {
                content: "Create a file called output.txt containing 'hello world'.".into(),
            },
            Message::Assistant {
                content: String::new(),
                tool_calls: vec![ToolCall {
                    id: Some("call_0_0".into()),
                    name: "write_file".into(),
                    arguments: args,
                }],
            },
            Message::ToolResult {
                tool_use_id: "call_0_0".into(),
                content: String::new(),
            },
        ];
        let prompt = b
            .chat_template()
            .expect("template")
            .render(&msgs, tool.as_array().map(|a| a.as_slice()), ThinkingMode::Off)
            .expect("render");
        let toks = b.encode(&prompt).expect("encode");
        let (out, stopped) = greedy_generate(&mut b, &toks, 300);
        let text = b.decode(&out).unwrap_or_default();
        eprintln!("AFTER-TOOL OUT ({} toks, stopped={stopped}): {text:?}", out.len());
        assert!(stopped, "model did not stop (rambled to the cap): {text:?}");
        assert!(
            !text.contains("call:"),
            "model emitted another tool call instead of concluding: {text:?}"
        );
    }

    #[test]
    #[ignore = "requires mlx-community/gemma-4-12B-it-4bit weights in the HF cache"]
    fn prefix_reuse_matches_full_prefill() {
        let Some(dir) = gemma4_dir() else {
            return;
        };
        let mut b = Gemma4Backend::load(&dir).expect("load");
        let p1 = b
            .encode("The capital of France is Paris. The capital of Italy is")
            .expect("encode");
        b.clear_kv_cache();
        let o0 = b.begin_prompt(&p1);
        assert_eq!(o0, 0, "fresh cache must prefill from 0");
        let full = b.forward(&p1[o0..], o0).expect("full prefill");
        let half = p1.len() / 2;
        b.clear_kv_cache();
        let oa = b.begin_prompt(&p1[..half]);
        let _ = b.forward(&p1[..half][oa..], oa).expect("prefix prefill");
        let off = b.begin_prompt(&p1);
        assert_eq!(off, half, "expected to reuse the whole prefix (off={off}, half={half})");
        let reused = b.forward(&p1[off..], off).expect("suffix prefill");
        assert_eq!(
            argmax(&full),
            argmax(&reused),
            "prefix-reuse must match full prefill"
        );
    }

    #[test]
    #[ignore = "requires mlx-community/gemma-4-12B-it-4bit weights in the HF cache"]
    fn matches_reference_on_long_context() {
        let Some(dir) = gemma4_dir() else {
            return;
        };
        let mut b = Gemma4Backend::load(&dir).expect("load");
        let prompt = "The quick brown fox jumps over the lazy dog. ".repeat(200)
            + "The capital of France is";
        let toks = b.encode(&prompt).expect("encode");
        assert!(
            toks.len() > 1024,
            "prompt must exceed the sliding window (got {})",
            toks.len()
        );
        let logits = b.prefill_logits(&toks).expect("prefill");
        let best = argmax(&logits);
        eprintln!("long-context ({} toks) argmax={best} (reference=1048)", toks.len());
        assert_eq!(best, 1048, "sliding-window long-context must match mlx_lm");
    }

    #[test]
    fn sliding_window_mask_restricts_to_window() {
        let v = windowed_causal_mask_values(4, 4, 2);
        let at = |i: usize, j: usize| v[i * 4 + j];
        assert_eq!(at(0, 0), 0.0);
        assert!(at(0, 1).is_infinite());
        assert!(at(2, 0).is_infinite());
        assert_eq!(at(2, 1), 0.0);
        assert_eq!(at(2, 2), 0.0);
        assert!(at(2, 3).is_infinite());
        let d = windowed_causal_mask_values(1, 10, 3);
        assert!(d[5].is_infinite());
        assert_eq!(d[7], 0.0);
        assert_eq!(d[9], 0.0);
    }

    #[test]
    fn unbounded_window_is_plain_causal() {
        let v = windowed_causal_mask_values(3, 3, i32::MAX);
        let at = |i: usize, j: usize| v[i * 3 + j];
        assert_eq!(at(0, 0), 0.0);
        assert!(at(0, 2).is_infinite());
        assert_eq!(at(2, 0), 0.0);
        assert_eq!(at(2, 2), 0.0);
    }

    #[test]
    fn parses_gemma4_tool_call() {
        let text = "call:get_weather{city:<|\"|>New York<|\"|>,days:3,metric:true}";
        let (clean, calls) = parse_gemma4_tool_calls(text);
        assert_eq!(calls.len(), 1, "clean={clean:?}");
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(
            calls[0].arguments.get("city").unwrap(),
            &serde_json::json!("New York")
        );
        assert_eq!(
            calls[0].arguments.get("days").unwrap(),
            &serde_json::json!(3)
        );
        assert_eq!(
            calls[0].arguments.get("metric").unwrap(),
            &serde_json::json!(true)
        );
    }

    #[test]
    fn tool_call_with_surrounding_text_and_no_call_is_clean() {
        let (clean, calls) = parse_gemma4_tool_calls("Just a normal answer.");
        assert!(calls.is_empty());
        assert_eq!(clean, "Just a normal answer.");
    }

    #[test]
    #[ignore = "requires mlx-community/gemma-4-12B-it-4bit weights in the HF cache"]
    fn loads_and_prefills() {
        let Some(dir) = gemma4_dir() else {
            eprintln!("SKIP: gemma-4-12B-it-4bit weights not in HF cache");
            return;
        };
        let mut backend = Gemma4Backend::load(&dir).expect("load gemma4 backend");
        let toks = backend.encode("The capital of France is").expect("encode");
        assert!(!toks.is_empty());
        let logits = backend.prefill_logits(&toks).expect("prefill forward");
        assert!(
            logits.len() >= 262_000,
            "expected ~262k vocab logits, got {}",
            logits.len()
        );
        assert!(
            logits.iter().take(1000).all(|v| v.is_finite()),
            "logits contain non-finite values"
        );
        let mut idx: Vec<usize> = (0..logits.len()).collect();
        idx.sort_by(|&a, &b| logits[b].total_cmp(&logits[a]));
        eprintln!("gemma4 prefill OK: {} logits. Top-5 next tokens:", logits.len());
        for &t in idx.iter().take(5) {
            eprintln!(
                "  {} {:?} logit={:.4}",
                t,
                backend.decode(&[t as u32]),
                logits[t]
            );
        }
        assert_eq!(idx[0], 236772, "argmax must match the mlx_lm reference");
        let t1 = idx[0] as u32;
        let decode_logits = backend.forward(&[t1], toks.len()).expect("decode step");
        let decode_argmax = argmax(&decode_logits);
        let mut extended = toks.clone();
        extended.push(t1);
        let full = backend.prefill_logits(&extended).expect("full reprefill");
        let full_argmax = argmax(&full);
        eprintln!("decode argmax={decode_argmax} vs full-reprefill argmax={full_argmax}");
        assert_eq!(
            decode_argmax, full_argmax,
            "KV-cache decode must match a fresh full prefill"
        );
    }
}