use std::collections::HashMap;
use std::path::{Path, PathBuf};
use mlx_rs::module::{Module, ModuleParameters, Param};
use mlx_rs::nn;
use mlx_rs::ops;
use mlx_rs::ops::indexing::IndexOp;
use mlx_rs::Array;
use tokenizers::Tokenizer;
use tracing::info;
use super::mlx::{build_qembedding, build_qlinear, QEmbedding, QLinear, QuantConfig};
use crate::InferenceError;
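/// Builds an additive causal attention mask of shape [1, 1, T, T]:
/// 0.0 where key position j <= query position i, -inf elsewhere.
/// For T = 3 the rows are:
///   [   0, -inf, -inf]
///   [   0,    0, -inf]
///   [   0,    0,    0]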
fn build_causal_mask_additive(
seq_len: i32,
dtype: mlx_rs::Dtype,
) -> Result<Array, mlx_rs::error::Exception> {
let rows: Vec<f32> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if j <= i { 0.0 } else { f32::NEG_INFINITY }))
.collect();
let mask = Array::from_slice(&rows, &[1, 1, seq_len, seq_len]);
mask.as_dtype(dtype)
}
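/// Best-effort debug dump: writes `{dir}/{name}.bin` (row-major f32,
/// little-endian) plus `{dir}/{name}.meta` holding the shape. All I/O
/// errors are deliberately ignored; this is parity tooling, not a hot path.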
fn dump_hidden(dir: &str, name: &str, t: &Array) {
let t_f32 = match t.as_dtype(mlx_rs::Dtype::Float32) {
Ok(a) => a,
Err(_) => return,
};
let _ = mlx_rs::transforms::eval([&t_f32]);
let shape = t_f32.shape().to_vec();
let data: &[f32] = t_f32.as_slice();
let bin_path = format!("{dir}/{name}.bin");
let meta_path = format!("{dir}/{name}.meta");
let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
let _ = std::fs::write(&bin_path, &bytes);
let _ = std::fs::write(&meta_path, format!("{shape:?}\n"));
}
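/// Text-encoder configuration. `Default` matches the
/// mlx-community/gemma-3-12b-it-4bit checkpoint loaded below.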
pub struct Gemma3Config {
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub head_dim: usize,
pub intermediate_size: usize,
pub vocab_size: usize,
pub rms_norm_eps: f32,
pub layer_types: Vec<LayerKind>,
pub rope_theta_global: f32,
pub rope_theta_sliding: f32,
pub partial_rotary_factor: f32,
pub quant: QuantConfig,
}
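/// Attention variant of a decoder layer: Gemma 3 interleaves
/// sliding-window layers with periodic full (global) attention layers.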
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerKind {
Sliding,
Full,
}
impl Default for Gemma3Config {
fn default() -> Self {
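// Every sixth layer uses full (global) attention; the rest use
// sliding-window attention (Gemma 3's 5:1 interleave).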
let mut layer_types = Vec::with_capacity(48);
for i in 0..48 {
if (i + 1) % 6 == 0 {
layer_types.push(LayerKind::Full);
} else {
layer_types.push(LayerKind::Sliding);
}
}
Self {
hidden_size: 3840,
num_hidden_layers: 48,
num_attention_heads: 16,
num_key_value_heads: 8,
head_dim: 256,
intermediate_size: 15360,
vocab_size: 262208,
rms_norm_eps: 1e-6,
layer_types,
rope_theta_global: 1_000_000.0,
rope_theta_sliding: 10_000.0,
partial_rotary_factor: 0.25,
quant: QuantConfig {
group_size: 64,
bits: 4,
},
}
}
}
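/// Gemma-style RMSNorm: the stored weight is zero-centered, so the
/// effective scale is `1 + weight` (see `forward`).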
struct GemmaRmsNorm {
weight: Array,
eps: f32,
}
impl GemmaRmsNorm {
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
eps: f32,
) -> Result<Self, InferenceError> {
let weight = tensors
.get(&format!("{prefix}.weight"))
.cloned()
.ok_or_else(|| InferenceError::InferenceFailed(format!("missing {prefix}.weight")))?;
Ok(Self { weight, eps })
}
fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let eps = Array::from_f32(self.eps);
let var = ops::multiply(x, x)?.mean_axes(&[-1], true)?;
let scale = ops::rsqrt(&ops::add(&var, &eps)?)?;
let normed = ops::multiply(x, &scale)?;
let one = Array::from_f32(1.0);
let weight_plus_one = ops::add(&self.weight, &one)?;
ops::multiply(&normed, &weight_plus_one)
}
}
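/// GeGLU feed-forward block: `down_proj(gelu(gate_proj(x)) * up_proj(x))`.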
struct GemmaMlp {
gate_proj: QLinear,
up_proj: QLinear,
down_proj: QLinear,
}
impl GemmaMlp {
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
quant: Option<&QuantConfig>,
) -> Result<Self, InferenceError> {
Ok(Self {
gate_proj: build_qlinear(tensors, &format!("{prefix}.gate_proj"), quant)?,
up_proj: build_qlinear(tensors, &format!("{prefix}.up_proj"), quant)?,
down_proj: build_qlinear(tensors, &format!("{prefix}.down_proj"), quant)?,
})
}
fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
let gate = nn::gelu(&self.gate_proj.forward(x)?)?;
let up = self.up_proj.forward(x)?;
self.down_proj.forward(&ops::multiply(&gate, &up)?)
}
}
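/// Grouped-query attention with per-head Q/K RMS norms. Sliding layers
/// rotate the full head dimension at the local RoPE base; full layers use
/// the global base and rotate only `head_dim * partial_rotary_factor` dims.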
struct GemmaAttention {
q_proj: QLinear,
k_proj: QLinear,
v_proj: QLinear,
o_proj: QLinear,
q_norm: GemmaRmsNorm,
k_norm: GemmaRmsNorm,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
rope_theta: f32,
rope_dim: usize,
}
impl GemmaAttention {
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
cfg: &Gemma3Config,
layer_kind: LayerKind,
) -> Result<Self, InferenceError> {
let quant = Some(&cfg.quant);
let (rope_theta, rope_dim) = match layer_kind {
LayerKind::Sliding => (cfg.rope_theta_sliding, cfg.head_dim),
LayerKind::Full => (
cfg.rope_theta_global,
((cfg.head_dim as f32) * cfg.partial_rotary_factor).round() as usize,
),
};
Ok(Self {
q_proj: build_qlinear(tensors, &format!("{prefix}.q_proj"), quant)?,
k_proj: build_qlinear(tensors, &format!("{prefix}.k_proj"), quant)?,
v_proj: build_qlinear(tensors, &format!("{prefix}.v_proj"), quant)?,
o_proj: build_qlinear(tensors, &format!("{prefix}.o_proj"), quant)?,
q_norm: GemmaRmsNorm::load(tensors, &format!("{prefix}.q_norm"), cfg.rms_norm_eps)?,
k_norm: GemmaRmsNorm::load(tensors, &format!("{prefix}.k_norm"), cfg.rms_norm_eps)?,
num_heads: cfg.num_attention_heads,
num_kv_heads: cfg.num_key_value_heads,
head_dim: cfg.head_dim,
rope_theta,
rope_dim,
})
}
fn forward(
&mut self,
x: &Array,
combined_mask: Option<&Array>,
) -> Result<Array, mlx_rs::error::Exception> {
let s = x.shape();
let (batch, seq_len) = (s[0], s[1]);
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
let reshape_q = ops::reshape(
&q,
&[batch, seq_len, self.num_heads as i32, self.head_dim as i32],
)?;
let reshape_k = ops::reshape(
&k,
&[
batch,
seq_len,
self.num_kv_heads as i32,
self.head_dim as i32,
],
)?;
let reshape_v = ops::reshape(
&v,
&[
batch,
seq_len,
self.num_kv_heads as i32,
self.head_dim as i32,
],
)?;
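// [B, T, H, D] -> [B, H, T, D] for attention.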
let q = ops::transpose_axes(&reshape_q, &[0, 2, 1, 3])?;
let k = ops::transpose_axes(&reshape_k, &[0, 2, 1, 3])?;
let v = ops::transpose_axes(&reshape_v, &[0, 2, 1, 3])?;
let q = self.q_norm.forward(&q)?;
let k = self.k_norm.forward(&k)?;
let q = self.apply_rope(&q)?;
let k = self.apply_rope(&k)?;
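// GQA: tile K/V so every query head has a matching key/value head.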
let n_rep = self.num_heads / self.num_kv_heads;
let k = self.repeat_heads(&k, n_rep as i32, seq_len)?;
let v = self.repeat_heads(&v, n_rep as i32, seq_len)?;
let scale = Array::from_f32(1.0 / (self.head_dim as f32).sqrt());
let scores = ops::multiply(
&ops::matmul(&q, &ops::transpose_axes(&k, &[0, 1, 3, 2])?)?,
&scale,
)?;
let mask = match combined_mask {
Some(m) => m.as_dtype(scores.dtype())?,
None => build_causal_mask_additive(seq_len, scores.dtype())?,
};
let masked = ops::add(&scores, &mask)?;
let attn = ops::softmax_axis(&masked, -1, None)?;
let out = ops::matmul(&attn, &v)?;
let out = ops::transpose_axes(&out, &[0, 2, 1, 3])?;
let merged = ops::reshape(
&out,
&[batch, seq_len, (self.num_heads * self.head_dim) as i32],
)?;
self.o_proj.forward(&merged)
}
fn apply_rope(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
if self.rope_dim == self.head_dim {
return mlx_rs::fast::rope(
x,
self.head_dim as i32,
false,
self.rope_theta,
1.0,
0,
None::<&Array>,
);
}
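// Partial RoPE: rotate only the first `rope_dim` features and pass the
// remaining `head_dim - rope_dim` features through unchanged.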
let rd = self.rope_dim as i32;
let head_dim = self.head_dim as i32;
let rot = x.index((.., .., .., ..rd));
let pass = x.index((.., .., .., rd..));
let rotated = mlx_rs::fast::rope(&rot, rd, false, self.rope_theta, 1.0, 0, None::<&Array>)?;
let combined = ops::concatenate_axis(&[&rotated, &pass], -1)?;
let s = combined.shape();
debug_assert_eq!(s[s.len() - 1], head_dim);
Ok(combined)
}
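/// Expands KV heads for GQA: [B, H_kv, T, D] -> [B, H_kv * n_rep, T, D].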
fn repeat_heads(
&self,
x: &Array,
n_rep: i32,
seq_len: i32,
) -> Result<Array, mlx_rs::error::Exception> {
if n_rep == 1 {
return Ok(x.clone());
}
let s = x.shape();
let (b, h_kv, _t, d) = (s[0], s[1], s[2], s[3]);
let x5 = ops::reshape(x, &[b, h_kv, 1, seq_len, d])?;
let tiled = ops::tile(&x5, &[1, 1, n_rep, 1, 1])?;
ops::reshape(&tiled, &[b, h_kv * n_rep, seq_len, d])
}
}
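/// One decoder block with Gemma's sandwich normalization: pre- and
/// post-norms around both the attention and the feed-forward sublayers.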
struct GemmaBlock {
input_layernorm: GemmaRmsNorm,
attn: GemmaAttention,
post_attention_layernorm: GemmaRmsNorm,
pre_feedforward_layernorm: GemmaRmsNorm,
mlp: GemmaMlp,
post_feedforward_layernorm: GemmaRmsNorm,
}
impl GemmaBlock {
fn load(
tensors: &HashMap<String, Array>,
prefix: &str,
cfg: &Gemma3Config,
layer_kind: LayerKind,
) -> Result<Self, InferenceError> {
let eps = cfg.rms_norm_eps;
Ok(Self {
input_layernorm: GemmaRmsNorm::load(
tensors,
&format!("{prefix}.input_layernorm"),
eps,
)?,
attn: GemmaAttention::load(tensors, &format!("{prefix}.self_attn"), cfg, layer_kind)?,
post_attention_layernorm: GemmaRmsNorm::load(
tensors,
&format!("{prefix}.post_attention_layernorm"),
eps,
)?,
pre_feedforward_layernorm: GemmaRmsNorm::load(
tensors,
&format!("{prefix}.pre_feedforward_layernorm"),
eps,
)?,
mlp: GemmaMlp::load(tensors, &format!("{prefix}.mlp"), Some(&cfg.quant))?,
post_feedforward_layernorm: GemmaRmsNorm::load(
tensors,
&format!("{prefix}.post_feedforward_layernorm"),
eps,
)?,
})
}
fn forward(
&mut self,
x: &Array,
mask: Option<&Array>,
) -> Result<Array, mlx_rs::error::Exception> {
let residual = x.clone();
let h = self.input_layernorm.forward(x)?;
let h = self.attn.forward(&h, mask)?;
let h = self.post_attention_layernorm.forward(&h)?;
let x = ops::add(&residual, &h)?;
let residual = x.clone();
let h = self.pre_feedforward_layernorm.forward(&x)?;
let h = self.mlp.forward(&h)?;
let h = self.post_feedforward_layernorm.forward(&h)?;
ops::add(&residual, &h)
}
}
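/// Gemma 3 12B text encoder that produces stacked per-layer hidden states
/// for use as LTX conditioning features.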
pub struct Gemma3TextEncoder {
config: Gemma3Config,
embed_tokens: QEmbedding,
layers: Vec<GemmaBlock>,
final_norm: GemmaRmsNorm,
tokenizer: Tokenizer,
embed_scale: f32,
}
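// SAFETY: an assertion, not a proof — these impls rely on the encoder
// being driven from one thread at a time (all mutation goes through
// `&mut self`); mlx arrays are not guaranteed thread-safe on their own.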
unsafe impl Send for Gemma3TextEncoder {}
unsafe impl Sync for Gemma3TextEncoder {}
impl Gemma3TextEncoder {
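/// Attempts to load the default 4-bit checkpoint from the local
/// Hugging Face hub cache. Returns `Ok(None)` when no snapshot is cached.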
pub fn try_load_default() -> Result<Option<Self>, InferenceError> {
let snapshots = std::env::var("HOME")
.map(PathBuf::from)
.unwrap_or_default()
.join(".cache/huggingface/hub/models--mlx-community--gemma-3-12b-it-4bit/snapshots");
if !snapshots.exists() {
return Ok(None);
}
let Some(snapshot) = std::fs::read_dir(&snapshots)
.ok()
.and_then(|entries| entries.flatten().next().map(|e| e.path()))
else {
return Ok(None);
};
Self::load(&snapshot).map(Some)
}
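/// Loads the tokenizer and both safetensors shards from `model_dir`, then
/// assembles the quantized embedding, decoder blocks, and final norm.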
pub fn load(model_dir: &Path) -> Result<Self, InferenceError> {
info!(dir = %model_dir.display(), "loading Gemma 3 12B text encoder");
let tokenizer_path = model_dir.join("tokenizer.json");
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| InferenceError::InferenceFailed(format!("load tokenizer: {e}")))?;
let mut tensors: HashMap<String, Array> = HashMap::new();
for shard in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let path = model_dir.join(shard);
if !path.exists() {
return Err(InferenceError::InferenceFailed(format!(
"missing Gemma shard: {}",
path.display()
)));
}
let loaded = Array::load_safetensors(&path)
.map_err(|e| InferenceError::InferenceFailed(format!("load {shard}: {e}")))?;
for (k, v) in loaded {
tensors.insert(k, v);
}
}
info!(tensors = tensors.len(), "Gemma shards loaded");
let cfg = Gemma3Config::default();
let text_pfx = "language_model.model";
let embed_tokens = build_qembedding(
&tensors,
&format!("{text_pfx}.embed_tokens"),
Some(&cfg.quant),
)?;
let final_norm =
GemmaRmsNorm::load(&tensors, &format!("{text_pfx}.norm"), cfg.rms_norm_eps)?;
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
for i in 0..cfg.num_hidden_layers {
let kind = cfg.layer_types[i];
let block = GemmaBlock::load(&tensors, &format!("{text_pfx}.layers.{i}"), &cfg, kind)?;
layers.push(block);
}
info!(layers = layers.len(), "Gemma 3 decoder blocks loaded");
let embed_scale = (cfg.hidden_size as f32).sqrt();
Ok(Self {
config: cfg,
embed_tokens,
layers,
final_norm,
tokenizer,
embed_scale,
})
}
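/// Encodes `prompt` into LTX conditioning features. Token ids are
/// truncated or left-padded to `max_tokens`; the hidden states of the
/// embedding plus every decoder block are stacked, RMS-normalized, and
/// flattened to `[1, max_tokens, hidden_size * (num_hidden_layers + 1)]`.
/// Returns the feature array and the number of valid (non-pad) tokens.
///
/// Illustrative call, assuming a cached snapshot (error handling elided,
/// hence `ignore`):
/// ```ignore
/// let mut enc = Gemma3TextEncoder::try_load_default()?.expect("cached snapshot");
/// let (features, n_valid) = enc.encode_for_ltx("a red fox at dawn", 256)?;
/// ```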
pub fn encode_for_ltx(
&mut self,
prompt: &str,
max_tokens: usize,
) -> Result<(Array, usize), InferenceError> {
let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());
let encoding = self
.tokenizer
.encode(prompt, true)
.map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
let raw_ids: Vec<i32> = encoding.get_ids().iter().map(|&id| id as i32).collect();
let n_valid = raw_ids.len().min(max_tokens);
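// Truncate to `max_tokens`, or left-pad with id 0 so the real tokens sit
// at the end of the window.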
let ids: Vec<i32> = if raw_ids.len() >= max_tokens {
raw_ids[..max_tokens].to_vec()
} else {
let pad_count = max_tokens - n_valid;
let mut padded = vec![0i32; max_tokens];
padded[pad_count..].copy_from_slice(&raw_ids);
padded
};
let seq_len = ids.len() as i32;
let token_ids = Array::from_slice(&ids, &[1, seq_len]);
let mut hidden = self.embed_tokens.forward(&token_ids).map_err(map_err)?;
let scale = Array::from_f32(self.embed_scale);
hidden = ops::multiply(&hidden, &scale).map_err(map_err)?;
let dump_dir = std::env::var("CAR_DUMP_GEMMA_HIDDEN").ok();
if let Some(ref dir) = dump_dir {
std::fs::create_dir_all(dir).ok();
info!(
tokens = ?ids.iter().take(10).copied().collect::<Vec<_>>(),
"Gemma parity dump: prompt tokens (first 10)"
);
dump_hidden(dir, "hidden_000_embed", &hidden);
}
let mut all_hidden: Vec<Array> = Vec::with_capacity(self.config.num_hidden_layers + 1);
all_hidden.push(hidden.clone());
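// Additive mask combining causality (j <= i) with left-padding: padded
// key columns (j < pad_count) get a large negative bias.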
let t = seq_len as usize;
let pad_count = t.saturating_sub(n_valid);
const NEG_BIG: f32 = -1.0e9;
let mut mask_vals = vec![0.0f32; t * t];
for i in 0..t {
for j in 0..t {
let causal = if j <= i { 0.0 } else { NEG_BIG };
let pad = if j < pad_count { NEG_BIG } else { 0.0 };
mask_vals[i * t + j] = causal + pad;
}
}
let combined_mask = Array::from_slice(&mask_vals, &[1, 1, t as i32, t as i32]);
for (i, block) in self.layers.iter_mut().enumerate() {
hidden = block
.forward(&hidden, Some(&combined_mask))
.map_err(map_err)?;
all_hidden.push(hidden.clone());
if let Some(ref dir) = dump_dir {
dump_hidden(dir, &format!("hidden_{:03}_block{:02}", i + 1, i), &hidden);
}
}
let final_hidden = self.final_norm.forward(&hidden).map_err(map_err)?;
if let Some(ref dir) = dump_dir {
dump_hidden(dir, "hidden_final_norm", &final_hidden);
}
let ltx_dump_dir = std::env::var("CAR_DUMP_LTX_STAGE").ok();
if let Some(ref dir) = ltx_dump_dir {
let _ = std::fs::create_dir_all(dir);
if let Ok(h_f32) = hidden.as_dtype(mlx_rs::Dtype::Float32) {
let _ = mlx_rs::transforms::eval([&h_f32]);
let shape = h_f32.shape().to_vec();
let data: &[f32] = h_f32.as_slice();
let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
let _ = std::fs::write(format!("{dir}/gemma_hidden_final.bin"), &bytes);
let _ = std::fs::write(
format!("{dir}/gemma_hidden_final.meta"),
format!("{shape:?}\n"),
);
}
}
// `final_hidden` only feeds the parity dump above; the LTX features are
// built from the raw per-layer states collected in `all_hidden`.
let _ = final_hidden;
// Stack the L+1 hidden states along a new trailing axis: [B, T, H, L+1].
let stacked = {
let with_new_axis: Vec<Array> = all_hidden
.iter()
.map(|h| ops::expand_dims(h, -1))
.collect::<Result<_, _>>()
.map_err(map_err)?;
let refs: Vec<&Array> = with_new_axis.iter().collect();
ops::concatenate_axis(&refs, -1).map_err(map_err)?
};
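// Per-token RMS normalization over the hidden axis (2) of the stacked
// states, then flatten the layer axis into one feature dimension.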
let eps = Array::from_f32(self.config.rms_norm_eps);
let var = ops::multiply(&stacked, &stacked)
.map_err(map_err)?
.mean_axes(&[2], true)
.map_err(map_err)?;
let scale = ops::rsqrt(&ops::add(&var, &eps).map_err(map_err)?).map_err(map_err)?;
let normed = ops::multiply(&stacked, &scale).map_err(map_err)?;
let s = normed.shape();
let feat_dim = (self.config.hidden_size * (self.config.num_hidden_layers + 1)) as i32;
let mut out = ops::reshape(&normed, &[s[0], s[1], feat_dim]).map_err(map_err)?;
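// Zero the left-padded positions so downstream consumers can ignore them.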
let t = s[1] as usize;
if n_valid < t {
let pad_count = t - n_valid;
let mask_vals: Vec<f32> = (0..t)
.map(|i| if i < pad_count { 0.0 } else { 1.0 })
.collect();
let mask = Array::from_slice(&mask_vals, &[1, t as i32, 1]);
out = ops::multiply(&out, &mask).map_err(map_err)?;
}
if let Ok(dir) = std::env::var("CAR_DUMP_LTX_STAGE") {
let _ = std::fs::create_dir_all(&dir);
if let Ok(h_f32) = out.as_dtype(mlx_rs::Dtype::Float32) {
let _ = mlx_rs::transforms::eval([&h_f32]);
let shape = h_f32.shape().to_vec();
let data: &[f32] = h_f32.as_slice();
let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
let _ = std::fs::write(format!("{dir}/gemma_stacked.bin"), &bytes);
let _ = std::fs::write(format!("{dir}/gemma_stacked.meta"), format!("{shape:?}\n"));
}
}
Ok((out, n_valid))
}
}