car-inference 0.14.0

Local model inference for CAR — Candle backend with Qwen3 models
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
//! Gemma 3 12B text encoder — the text conditioner LTX-2.3 was trained with.
//!
//! Produces the 188160-dim per-token feature tensor that feeds
//! `TextEmbeddingConnector::video_aggregate_embed` / `audio_aggregate_embed`:
//!
//! 1. Tokenize prompt via Gemma tokenizer (256k vocab).
//! 2. Run the 48-layer text-only decoder.
//! 3. Collect 49 hidden states (embedding layer + 48 decoder outputs).
//! 4. Stack on a new last axis → `[B, T, 3840, 49]`.
//! 5. Per-token RMS-norm over the 3840 axis.
//! 6. Reshape → `[B, T, 188160]` where `188160 = 3840 × 49`.
//!
//! Weights come from `mlx-community/gemma-3-12b-it-4bit` (4-bit MLX quant,
//! group_size=64). We only use the text portion
//! (`language_model.model.*`) — the multimodal vision+projector heads are
//! ignored.
//!
//! Architectural notes vs vanilla Llama / Qwen3:
//! - **RMSNorm** uses `(weight + 1)` as the scale, not just `weight`.
//! - Each block has **four** layernorms: pre-attn, post-attn, pre-ff, post-ff,
//!   applied as `x + post_ln(attn(pre_ln(x)))` etc.
//! - Embedding tokens are scaled by `sqrt(hidden_size)` before entering the
//!   first block (Gemma 3 embedding normalizer).
//! - GQA: 16 query heads × 256 head_dim = 4096; 8 kv heads.
//! - Attention q/k have per-head RMS-norm (weight `[head_dim=256]`).
//! - Alternating sliding (1024 window, rope_theta=10000) and full (rope_theta=1M,
//!   partial_rotary_factor=0.25 — only first 64 of 256 head dims get RoPE).
//!   For short prompts (≤sliding_window), sliding and full attention produce
//!   the same result over the full seq. We implement full attention
//!   throughout and respect the per-layer RoPE variant.

use std::collections::HashMap;
use std::path::{Path, PathBuf};

use mlx_rs::module::{Module, ModuleParameters, Param};
use mlx_rs::nn;
use mlx_rs::ops;
use mlx_rs::ops::indexing::IndexOp;
use mlx_rs::Array;
use tokenizers::Tokenizer;
use tracing::info;

use super::mlx::{build_qembedding, build_qlinear, QEmbedding, QLinear, QuantConfig};
use crate::InferenceError;

/// Construct the additive causal attention mask of shape `[1, 1, T, T]`.
///
/// Entry `(row, col)` is `0.0` when `col <= row` (the key is at or before the
/// query) and `-inf` when `col > row` (a future key). The mask is meant to be
/// added to the raw attention scores before the softmax.
///
/// * `seq_len` — sequence length `T`.
/// * `dtype` — dtype the finished mask is cast to.
fn build_causal_mask_additive(
    seq_len: i32,
    dtype: mlx_rs::Dtype,
) -> Result<Array, mlx_rs::error::Exception> {
    let side = seq_len.max(0) as usize;
    let mut values: Vec<f32> = Vec::with_capacity(side * side);
    for row in 0..side {
        for col in 0..side {
            // Strictly-upper-triangular entries are the forbidden (future) keys.
            values.push(if col > row { f32::NEG_INFINITY } else { 0.0 });
        }
    }
    Array::from_slice(&values, &[1, 1, seq_len, seq_len]).as_dtype(dtype)
}

/// Save a `[1, T, D]` hidden state as a raw f32 binary file alongside a
/// `.meta` sidecar describing the shape. Activated per-encoding via the
/// `CAR_DUMP_GEMMA_HIDDEN` environment variable so it can be diff'd against
/// the upstream mlx_lm reference tensors.
fn dump_hidden(dir: &str, name: &str, t: &Array) {
    let t_f32 = match t.as_dtype(mlx_rs::Dtype::Float32) {
        Ok(a) => a,
        Err(_) => return,
    };
    let _ = mlx_rs::transforms::eval([&t_f32]);
    let shape = t_f32.shape().to_vec();
    let data: &[f32] = t_f32.as_slice();
    let bin_path = format!("{dir}/{name}.bin");
    let meta_path = format!("{dir}/{name}.meta");
    let bytes: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
    let _ = std::fs::write(&bin_path, &bytes);
    let _ = std::fs::write(&meta_path, format!("{shape:?}\n"));
}

/// Gemma 3 12B configuration used by LTX-2.3.
///
/// The `Default` impl in this file supplies the values for
/// `mlx-community/gemma-3-12b-it-4bit`; `Gemma3TextEncoder::load` always uses
/// that default rather than parsing `config.json`.
pub struct Gemma3Config {
    /// Width of the residual stream. Also determines the embedding scale
    /// (`sqrt(hidden_size)`) and the per-layer slice of the output features.
    pub hidden_size: usize,
    /// Number of decoder blocks that are loaded and run.
    pub num_hidden_layers: usize,
    /// Number of query heads.
    pub num_attention_heads: usize,
    /// Number of key/value heads (GQA); each is shared by
    /// `num_attention_heads / num_key_value_heads` query heads.
    pub num_key_value_heads: usize,
    /// Per-head feature dimension.
    pub head_dim: usize,
    /// MLP inner width. Not read by the code in this file (projection shapes
    /// come from the checkpoint tensors).
    pub intermediate_size: usize,
    /// Tokenizer vocabulary size. Not read by the code in this file.
    pub vocab_size: usize,
    /// Epsilon added to the mean square inside every RMS norm.
    pub rms_norm_eps: f32,
    /// Whether each layer is "full" or "sliding" attention. len = num_hidden_layers.
    pub layer_types: Vec<LayerKind>,
    /// RoPE base used by `LayerKind::Full` layers.
    pub rope_theta_global: f32,
    /// RoPE base used by `LayerKind::Sliding` layers.
    pub rope_theta_sliding: f32,
    /// Fraction of head_dim that receives rotary (global layers only).
    pub partial_rotary_factor: f32,
    /// Quantization parameters handed to the quantized linear / embedding
    /// builders.
    pub quant: QuantConfig,
}

/// Attention flavour of a decoder layer.
///
/// In this file the variant only selects the RoPE parameters of the layer
/// (see `GemmaAttention::load`); the attention itself is computed over the
/// whole sequence for both variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerKind {
    /// Local layer: uses `rope_theta_sliding` and rotates the full head dim.
    Sliding,
    /// Global layer: uses `rope_theta_global` and rotates only
    /// `head_dim * partial_rotary_factor` entries.
    Full,
}

impl Default for Gemma3Config {
    /// Defaults mirror the `mlx-community/gemma-3-12b-it-4bit` config.json.
    /// The `sliding_attention` × 5 + `full_attention` × 1 repeating pattern is
    /// encoded here directly rather than parsed at load time.
    fn default() -> Self {
        let mut layer_types = Vec::with_capacity(48);
        for i in 0..48 {
            // Pattern from config.json: 5 sliding then 1 full, repeated.
            if (i + 1) % 6 == 0 {
                layer_types.push(LayerKind::Full);
            } else {
                layer_types.push(LayerKind::Sliding);
            }
        }
        Self {
            hidden_size: 3840,
            num_hidden_layers: 48,
            num_attention_heads: 16,
            num_key_value_heads: 8,
            head_dim: 256,
            intermediate_size: 15360,
            vocab_size: 262208,
            rms_norm_eps: 1e-6,
            layer_types,
            rope_theta_global: 1_000_000.0,
            rope_theta_sliding: 10_000.0,
            partial_rotary_factor: 0.25,
            quant: QuantConfig {
                group_size: 64,
                bits: 4,
            },
        }
    }
}

// ─── RMSNorm (Gemma 3 variant with `weight + 1` scaling) ─────────────────

/// RMS normalization layer whose learned gain is applied as `weight + 1`.
struct GemmaRmsNorm {
    /// Learned gain read from `<prefix>.weight`; `1.0` is added at use time.
    weight: Array,
    /// Epsilon added to the mean square before the reciprocal square root.
    eps: f32,
}

impl GemmaRmsNorm {
    /// Look up `<prefix>.weight` in `tensors` and wrap it together with `eps`.
    ///
    /// # Errors
    /// `InferenceError::InferenceFailed` when the weight tensor is absent.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        eps: f32,
    ) -> Result<Self, InferenceError> {
        let key = format!("{prefix}.weight");
        match tensors.get(&key) {
            Some(weight) => Ok(Self {
                weight: weight.clone(),
                eps,
            }),
            None => Err(InferenceError::InferenceFailed(format!(
                "missing {prefix}.weight"
            ))),
        }
    }

    /// Normalize `x` over its last axis and apply the `(weight + 1)` gain:
    /// `x * rsqrt(mean(x², axis=-1) + eps) * (weight + 1)`.
    fn forward(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        // Mean of squares over the feature axis, kept as a size-1 axis so it
        // broadcasts back against `x`.
        let mean_sq = ops::multiply(x, x)?.mean_axes(&[-1], true)?;
        let inv_rms = ops::rsqrt(&ops::add(&mean_sq, &Array::from_f32(self.eps))?)?;
        let unit = ops::multiply(x, &inv_rms)?;
        // Gemma 3 convention: the gain is (weight + 1), not the bare weight.
        let gain = ops::add(&self.weight, &Array::from_f32(1.0))?;
        ops::multiply(&unit, &gain)
    }
}

// ─── GeGLU MLP (GELU-gated) ─────────────────────────────────────────────

/// Gated feed-forward block: `down_proj(act(gate_proj(x)) * up_proj(x))`.
struct GemmaMlp {
    /// Projection whose activated output gates `up_proj`.
    gate_proj: QLinear,
    /// Projection multiplied element-wise by the activated gate.
    up_proj: QLinear,
    /// Projection applied to the gated product.
    down_proj: QLinear,
}

impl GemmaMlp {
    /// Build the three projections from `<prefix>.gate_proj`,
    /// `<prefix>.up_proj` and `<prefix>.down_proj`.
    ///
    /// `quant` is forwarded unchanged to `build_qlinear`.
    ///
    /// # Errors
    /// Propagates whatever `build_qlinear` reports for a projection.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        quant: Option<&QuantConfig>,
    ) -> Result<Self, InferenceError> {
        Ok(Self {
            gate_proj: build_qlinear(tensors, &format!("{prefix}.gate_proj"), quant)?,
            up_proj: build_qlinear(tensors, &format!("{prefix}.up_proj"), quant)?,
            down_proj: build_qlinear(tensors, &format!("{prefix}.down_proj"), quant)?,
        })
    }

    /// Compute `down_proj(gelu_tanh(gate_proj(x)) * up_proj(x))`.
    ///
    /// Gemma 3 gates with the tanh approximation of GELU, not SiLU and not
    /// the exact erf form. `nn::gelu` is the exact variant, so the tanh
    /// approximation (`nn::gelu_approximate`) is used here to match the
    /// reference activation.
    fn forward(&mut self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let gate = nn::gelu_approximate(&self.gate_proj.forward(x)?)?;
        let up = self.up_proj.forward(x)?;
        self.down_proj.forward(&ops::multiply(&gate, &up)?)
    }
}

// ─── Attention (GQA with partial RoPE) ──────────────────────────────────

/// Grouped-query self-attention with per-head q/k RMS norm and (optionally
/// partial) rotary position embedding.
struct GemmaAttention {
    /// Query projection; its output is reshaped to `num_heads × head_dim`.
    q_proj: QLinear,
    /// Key projection; its output is reshaped to `num_kv_heads × head_dim`.
    k_proj: QLinear,
    /// Value projection; its output is reshaped to `num_kv_heads × head_dim`.
    v_proj: QLinear,
    /// Output projection applied to the merged heads.
    o_proj: QLinear,
    /// RMS norm applied to queries over the head_dim axis.
    q_norm: GemmaRmsNorm,
    /// RMS norm applied to keys over the head_dim axis.
    k_norm: GemmaRmsNorm,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads (repeated to match `num_heads`).
    num_kv_heads: usize,
    /// Per-head feature dimension; also sets the `1/sqrt(head_dim)` score scale.
    head_dim: usize,
    /// RoPE base frequency for this layer (chosen from the layer kind).
    rope_theta: f32,
    /// Number of head-dim entries that receive rotary (rest pass through).
    rope_dim: usize,
}

impl GemmaAttention {
    /// Assemble one attention layer from the tensors under `prefix`.
    ///
    /// The layer kind picks the rotary parameters:
    /// * `Sliding` — `rope_theta_sliding`, rotary over the whole `head_dim`.
    /// * `Full` — `rope_theta_global`, rotary over
    ///   `round(head_dim * partial_rotary_factor)` entries.
    ///
    /// # Errors
    /// Propagates failures from `build_qlinear` / `GemmaRmsNorm::load`.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        cfg: &Gemma3Config,
        layer_kind: LayerKind,
    ) -> Result<Self, InferenceError> {
        let rope_theta = match layer_kind {
            LayerKind::Sliding => cfg.rope_theta_sliding,
            LayerKind::Full => cfg.rope_theta_global,
        };
        let rope_dim = match layer_kind {
            LayerKind::Sliding => cfg.head_dim,
            LayerKind::Full => ((cfg.head_dim as f32) * cfg.partial_rotary_factor).round() as usize,
        };

        let projection =
            |name: &str| build_qlinear(tensors, &format!("{prefix}.{name}"), Some(&cfg.quant));
        let head_norm = |name: &str| {
            GemmaRmsNorm::load(tensors, &format!("{prefix}.{name}"), cfg.rms_norm_eps)
        };

        Ok(Self {
            q_proj: projection("q_proj")?,
            k_proj: projection("k_proj")?,
            v_proj: projection("v_proj")?,
            o_proj: projection("o_proj")?,
            q_norm: head_norm("q_norm")?,
            k_norm: head_norm("k_norm")?,
            num_heads: cfg.num_attention_heads,
            num_kv_heads: cfg.num_key_value_heads,
            head_dim: cfg.head_dim,
            rope_theta,
            rope_dim,
        })
    }

    fn forward(
        &mut self,
        x: &Array,
        combined_mask: Option<&Array>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        let s = x.shape();
        let (batch, seq_len) = (s[0], s[1]);

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        // Reshape to per-head: [B, T, H, D] → [B, H, T, D]
        let reshape_q = ops::reshape(
            &q,
            &[batch, seq_len, self.num_heads as i32, self.head_dim as i32],
        )?;
        let reshape_k = ops::reshape(
            &k,
            &[
                batch,
                seq_len,
                self.num_kv_heads as i32,
                self.head_dim as i32,
            ],
        )?;
        let reshape_v = ops::reshape(
            &v,
            &[
                batch,
                seq_len,
                self.num_kv_heads as i32,
                self.head_dim as i32,
            ],
        )?;
        let q = ops::transpose_axes(&reshape_q, &[0, 2, 1, 3])?;
        let k = ops::transpose_axes(&reshape_k, &[0, 2, 1, 3])?;
        let v = ops::transpose_axes(&reshape_v, &[0, 2, 1, 3])?;

        // Per-head RMS-norm on the last (head_dim) axis.
        let q = self.q_norm.forward(&q)?;
        let k = self.k_norm.forward(&k)?;

        // Apply RoPE. For partial rotary (global layers), split last axis into
        // [rotary_part, pass_through_part], rotate the first, then concat.
        let q = self.apply_rope(&q)?;
        let k = self.apply_rope(&k)?;

        // Repeat kv heads to match q heads.
        let n_rep = self.num_heads / self.num_kv_heads;
        let k = self.repeat_heads(&k, n_rep as i32, seq_len)?;
        let v = self.repeat_heads(&v, n_rep as i32, seq_len)?;

        // SDPA with scale 1/sqrt(head_dim) + (causal | combined) mask. Upstream
        // get_all_hidden_states passes `combined_mask = causal + pad_mask` so
        // queries never attend to left-padding keys; falling back to a plain
        // causal mask when no pad info is available keeps stand-alone generate
        // callers working.
        let scale = Array::from_f32(1.0 / (self.head_dim as f32).sqrt());
        let scores = ops::multiply(
            &ops::matmul(&q, &ops::transpose_axes(&k, &[0, 1, 3, 2])?)?,
            &scale,
        )?;
        let mask = match combined_mask {
            Some(m) => m.as_dtype(scores.dtype())?,
            None => build_causal_mask_additive(seq_len, scores.dtype())?,
        };
        let masked = ops::add(&scores, &mask)?;
        let attn = ops::softmax_axis(&masked, -1, None)?;
        let out = ops::matmul(&attn, &v)?;

        // Merge heads: [B, H, T, D] → [B, T, H*D].
        let out = ops::transpose_axes(&out, &[0, 2, 1, 3])?;
        let merged = ops::reshape(
            &out,
            &[batch, seq_len, (self.num_heads * self.head_dim) as i32],
        )?;
        self.o_proj.forward(&merged)
    }

    /// Apply rotary position embedding to `x` (`[B, H, T, head_dim]`).
    ///
    /// When `rope_dim == head_dim` the whole last axis is rotated. Otherwise
    /// only the first `rope_dim` entries are rotated and the remainder is
    /// concatenated back untouched, so the output shape equals the input's.
    ///
    /// NOTE(review): RoPE is always called with scale `1.0` and offset `0`.
    /// The published Gemma 3 config appears to specify linear RoPE scaling on
    /// global layers — confirm this matches the reference implementation.
    fn apply_rope(&self, x: &Array) -> Result<Array, mlx_rs::error::Exception> {
        let rotary = self.rope_dim as i32;
        let full = self.head_dim as i32;

        if self.rope_dim != self.head_dim {
            // Partial rotary: split the last axis, rotate the leading slice only.
            let leading = x.index((.., .., .., ..rotary));
            let trailing = x.index((.., .., .., rotary..));
            let leading = mlx_rs::fast::rope(
                &leading,
                rotary,
                false,
                self.rope_theta,
                1.0,
                0,
                None::<&Array>,
            )?;
            let joined = ops::concatenate_axis(&[&leading, &trailing], -1)?;
            // Sanity check (debug builds only): the split/concat kept head_dim.
            debug_assert_eq!(joined.shape().last().copied(), Some(full));
            return Ok(joined);
        }

        // Full rotary over the entire head_dim axis.
        mlx_rs::fast::rope(
            x,
            full,
            false,
            self.rope_theta,
            1.0,
            0,
            None::<&Array>,
        )
    }

    /// Replicate every kv head `n_rep` times so `[B, H_kv, T, D]` becomes
    /// `[B, H_kv * n_rep, T, D]`, with the copies of each kv head adjacent.
    ///
    /// `seq_len` is the `T` used for the intermediate reshapes. With
    /// `n_rep == 1` the input is returned as a clone.
    fn repeat_heads(
        &self,
        x: &Array,
        n_rep: i32,
        seq_len: i32,
    ) -> Result<Array, mlx_rs::error::Exception> {
        match n_rep {
            1 => Ok(x.clone()),
            _ => {
                let dims = x.shape();
                let batch = dims[0];
                let kv_heads = dims[1];
                let depth = dims[3];
                // Insert a unit axis after the head axis, tile it n_rep times,
                // then fold it back into the head axis.
                let grouped = ops::reshape(x, &[batch, kv_heads, 1, seq_len, depth])?;
                let repeated = ops::tile(&grouped, &[1, 1, n_rep, 1, 1])?;
                ops::reshape(&repeated, &[batch, kv_heads * n_rep, seq_len, depth])
            }
        }
    }
}

// ─── Transformer block (Gemma 3: 4 RMSNorms per block) ──────────────────

/// One decoder block: an attention half and a feed-forward half, each wrapped
/// as `x + post_norm(f(pre_norm(x)))`.
struct GemmaBlock {
    /// Norm applied to the block input before attention.
    input_layernorm: GemmaRmsNorm,
    /// Self-attention sub-layer.
    attn: GemmaAttention,
    /// Norm applied to the attention output before the residual add.
    post_attention_layernorm: GemmaRmsNorm,
    /// Norm applied before the MLP.
    pre_feedforward_layernorm: GemmaRmsNorm,
    /// Gated feed-forward sub-layer.
    mlp: GemmaMlp,
    /// Norm applied to the MLP output before the residual add.
    post_feedforward_layernorm: GemmaRmsNorm,
}

impl GemmaBlock {
    /// Load the four norms, the attention layer (`<prefix>.self_attn`) and the
    /// MLP (`<prefix>.mlp`) of one decoder block.
    ///
    /// # Errors
    /// Propagates the first failure reported by any sub-layer loader.
    fn load(
        tensors: &HashMap<String, Array>,
        prefix: &str,
        cfg: &Gemma3Config,
        layer_kind: LayerKind,
    ) -> Result<Self, InferenceError> {
        let norm = |name: &str| {
            GemmaRmsNorm::load(tensors, &format!("{prefix}.{name}"), cfg.rms_norm_eps)
        };

        let input_layernorm = norm("input_layernorm")?;
        let attn = GemmaAttention::load(tensors, &format!("{prefix}.self_attn"), cfg, layer_kind)?;
        let post_attention_layernorm = norm("post_attention_layernorm")?;
        let pre_feedforward_layernorm = norm("pre_feedforward_layernorm")?;
        let mlp = GemmaMlp::load(tensors, &format!("{prefix}.mlp"), Some(&cfg.quant))?;
        let post_feedforward_layernorm = norm("post_feedforward_layernorm")?;

        Ok(Self {
            input_layernorm,
            attn,
            post_attention_layernorm,
            pre_feedforward_layernorm,
            mlp,
            post_feedforward_layernorm,
        })
    }

    /// Apply the block to `x`; `mask` is handed straight to the attention
    /// layer.
    ///
    /// ```text
    /// y   = x + post_attention_layernorm(attn(input_layernorm(x)))
    /// out = y + post_feedforward_layernorm(mlp(pre_feedforward_layernorm(y)))
    /// ```
    fn forward(
        &mut self,
        x: &Array,
        mask: Option<&Array>,
    ) -> Result<Array, mlx_rs::error::Exception> {
        // Attention half.
        let attn_in = self.input_layernorm.forward(x)?;
        let attn_out = self.attn.forward(&attn_in, mask)?;
        let attn_out = self.post_attention_layernorm.forward(&attn_out)?;
        let after_attn = ops::add(x, &attn_out)?;

        // Feed-forward half.
        let ff_in = self.pre_feedforward_layernorm.forward(&after_attn)?;
        let ff_out = self.mlp.forward(&ff_in)?;
        let ff_out = self.post_feedforward_layernorm.forward(&ff_out)?;
        ops::add(&after_attn, &ff_out)
    }
}

// ─── Full text encoder ──────────────────────────────────────────────────

/// Gemma 3 text encoder: tokenizer, quantized embedding table and the stack
/// of decoder blocks whose hidden states become the LTX text features.
pub struct Gemma3TextEncoder {
    /// Model hyper-parameters (always `Gemma3Config::default()` when built
    /// through `load`).
    config: Gemma3Config,
    /// Quantized token embedding table.
    embed_tokens: QEmbedding,
    /// Decoder blocks, in execution order.
    layers: Vec<GemmaBlock>,
    /// Final RMS norm. Its output is only used for debug dumps; the stacked
    /// features use the un-normed block outputs.
    final_norm: GemmaRmsNorm,
    /// Tokenizer loaded from `tokenizer.json` in the model directory.
    tokenizer: Tokenizer,
    /// Cached embedding scale factor `sqrt(hidden_size)`.
    embed_scale: f32,
}

// SAFETY: QLinear / QEmbedding / Array are only mutated from within LtxBackend's
// single-thread inference path, and this encoder lives there too.
//
// NOTE(review): nothing in this file enforces that invariant — soundness of
// both impls rests entirely on callers never touching the encoder (or the MLX
// arrays it owns) from two threads at once. Confirm against the backend
// before sharing this type more widely.
unsafe impl Send for Gemma3TextEncoder {}
unsafe impl Sync for Gemma3TextEncoder {}

impl Gemma3TextEncoder {
    /// Try to load the encoder from the default HuggingFace cache location
    /// for `mlx-community/gemma-3-12b-it-4bit`
    /// (`$HOME/.cache/huggingface/hub/models--…/snapshots/<rev>`).
    ///
    /// Returns `Ok(None)` when `HOME` is unset or empty, the snapshots
    /// directory cannot be read, or it contains no snapshot directory, so the
    /// caller can fall back to unconditional generation.
    ///
    /// Snapshot selection is deterministic: entries that are not directories
    /// (e.g. stray `.DS_Store` files) are skipped, candidates are sorted by
    /// path, and the first one containing `tokenizer.json` wins. If none does,
    /// the first directory is still handed to [`Self::load`] so the caller
    /// gets a descriptive error instead of a silent `None`.
    ///
    /// # Errors
    /// Whatever [`Self::load`] reports for the chosen snapshot.
    pub fn try_load_default() -> Result<Option<Self>, InferenceError> {
        // Without a home directory there is no cache to look in; do not fall
        // back to a path relative to the current working directory.
        let Some(home) = std::env::var_os("HOME").filter(|h| !h.is_empty()) else {
            return Ok(None);
        };
        let snapshots = PathBuf::from(home)
            .join(".cache/huggingface/hub/models--mlx-community--gemma-3-12b-it-4bit/snapshots");

        let Ok(entries) = std::fs::read_dir(&snapshots) else {
            return Ok(None);
        };
        let mut candidates: Vec<PathBuf> = entries
            .flatten()
            .map(|entry| entry.path())
            .filter(|path| path.is_dir())
            .collect();
        // `read_dir` order is unspecified; sort so the choice is reproducible.
        candidates.sort();

        let chosen = candidates
            .iter()
            .find(|dir| dir.join("tokenizer.json").is_file())
            .or_else(|| candidates.first());

        match chosen {
            Some(snapshot) => Self::load(snapshot).map(Some),
            None => Ok(None),
        }
    }

    pub fn load(model_dir: &Path) -> Result<Self, InferenceError> {
        info!(dir = %model_dir.display(), "loading Gemma 3 12B text encoder");

        let tokenizer_path = model_dir.join("tokenizer.json");
        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::InferenceFailed(format!("load tokenizer: {e}")))?;

        // Concatenate both shards into a single tensor map.
        let mut tensors: HashMap<String, Array> = HashMap::new();
        for shard in [
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
        ] {
            let path = model_dir.join(shard);
            if !path.exists() {
                return Err(InferenceError::InferenceFailed(format!(
                    "missing Gemma shard: {}",
                    path.display()
                )));
            }
            let loaded = Array::load_safetensors(&path)
                .map_err(|e| InferenceError::InferenceFailed(format!("load {shard}: {e}")))?;
            for (k, v) in loaded {
                tensors.insert(k, v);
            }
        }
        info!(tensors = tensors.len(), "Gemma shards loaded");

        let cfg = Gemma3Config::default();
        let text_pfx = "language_model.model";

        let embed_tokens = build_qembedding(
            &tensors,
            &format!("{text_pfx}.embed_tokens"),
            Some(&cfg.quant),
        )?;
        let final_norm =
            GemmaRmsNorm::load(&tensors, &format!("{text_pfx}.norm"), cfg.rms_norm_eps)?;

        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            let kind = cfg.layer_types[i];
            let block = GemmaBlock::load(&tensors, &format!("{text_pfx}.layers.{i}"), &cfg, kind)?;
            layers.push(block);
        }
        info!(layers = layers.len(), "Gemma 3 decoder blocks loaded");

        let embed_scale = (cfg.hidden_size as f32).sqrt();
        Ok(Self {
            config: cfg,
            embed_tokens,
            layers,
            final_norm,
            tokenizer,
            embed_scale,
        })
    }

    /// Encode `prompt` into the `[1, T, F]` feature tensor consumed by the
    /// LTX-2.3 TextEmbeddingConnector, where `T == max_tokens` and
    /// `F = hidden_size * (num_hidden_layers + 1)` (188160 for the default
    /// config).
    ///
    /// Steps:
    /// 1. Tokenize with special tokens, keep at most `max_tokens` ids and
    ///    LEFT-pad with id `0` up to exactly `max_tokens`.
    /// 2. Embed, scale by `sqrt(hidden_size)`, and run every decoder block
    ///    under a combined causal + padding mask.
    /// 3. Stack the embedding output and every block output on a new last
    ///    axis, RMS-normalize each token over the hidden axis, and flatten.
    /// 4. Zero the feature rows of padding positions.
    ///
    /// `max_tokens` fixes the sequence length the decoder runs on, so keep it
    /// as small as the prompts allow.
    ///
    /// # Returns
    /// `(features, n_valid)` where `n_valid` is the number of real (non-pad)
    /// tokens; they occupy the LAST `n_valid` positions of the sequence axis.
    ///
    /// # Errors
    /// `InferenceError::TokenizationError` if tokenization fails, and
    /// `InferenceError::InferenceFailed` for any MLX op failure.
    ///
    /// # Debug dumps
    /// * `CAR_DUMP_GEMMA_HIDDEN=<dir>` — dumps the embedding output, every
    ///   block output, and the final-normed last state.
    /// * `CAR_DUMP_LTX_STAGE=<dir>` — dumps `gemma_hidden_final` (last block
    ///   output, pre-final-norm) and `gemma_stacked` (the returned features).
    pub fn encode_for_ltx(
        &mut self,
        prompt: &str,
        max_tokens: usize,
    ) -> Result<(Array, usize), InferenceError> {
        let map_err = |e: mlx_rs::error::Exception| InferenceError::InferenceFailed(e.to_string());

        // Tokenize + LEFT-pad to a fixed length for the connector.
        //
        // Upstream ltx-2-mlx (`GemmaFeaturesExtractorV2`) pads with
        // `padding='max_length', padding_side='left'`, i.e. pad tokens come
        // BEFORE the real prompt tokens. Right-padding puts the real tokens
        // at different sequence positions, which changes every
        // position-dependent attention pattern; that divergence compounds
        // across the decoder and everything downstream.
        let encoding = self
            .tokenizer
            .encode(prompt, true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        let raw_ids: Vec<i32> = encoding.get_ids().iter().map(|&id| id as i32).collect();
        let n_valid = raw_ids.len().min(max_tokens);
        let pad_count = max_tokens - n_valid;
        // Pad id is 0; over-long prompts keep their first `max_tokens` ids.
        let mut ids = vec![0i32; max_tokens];
        ids[pad_count..].copy_from_slice(&raw_ids[..n_valid]);
        let t = max_tokens;
        let seq_len = ids.len() as i32;
        let token_ids = Array::from_slice(&ids, &[1, seq_len]);

        // Embed + scale by sqrt(hidden_size). This is the "embedding layer"
        // hidden state (first of the stacked states).
        let mut hidden = self.embed_tokens.forward(&token_ids).map_err(map_err)?;
        let scale = Array::from_f32(self.embed_scale);
        hidden = ops::multiply(&hidden, &scale).map_err(map_err)?;

        // Both debug switches are read once up front.
        let dump_dir = std::env::var("CAR_DUMP_GEMMA_HIDDEN").ok();
        let ltx_dump_dir = std::env::var("CAR_DUMP_LTX_STAGE").ok();
        if let Some(dir) = ltx_dump_dir.as_deref() {
            let _ = std::fs::create_dir_all(dir);
        }
        if let Some(dir) = dump_dir.as_deref() {
            std::fs::create_dir_all(dir).ok();
            info!(
                tokens = ?ids.iter().take(10).copied().collect::<Vec<_>>(),
                "Gemma parity dump: prompt tokens (first 10)"
            );
            dump_hidden(dir, "hidden_000_embed", &hidden);
        }

        let mut all_hidden: Vec<Array> = Vec::with_capacity(self.config.num_hidden_layers + 1);
        all_hidden.push(hidden.clone());

        // Combined causal + pad mask, shared by every decoder layer. Mirrors
        // upstream base_encoder.get_all_hidden_states:
        //   causal_mask = triu(full(T, T, -1e9), k=1)
        //   pad_mask    = (1 - attention_mask)[:, None, None, :] * -1e9
        //   combined    = causal_mask + pad_mask
        // With left-padding the first `pad_count` key positions are padding;
        // without the pad term real queries attend to them and every
        // downstream hidden state is polluted. -1e9 (not -inf) matches the
        // upstream constant and stays finite after low-precision casts.
        const NEG_BIG: f32 = -1.0e9;
        let mut mask_vals = vec![0.0f32; t * t];
        for (query, row) in mask_vals.chunks_exact_mut(t.max(1)).enumerate() {
            for (key, slot) in row.iter_mut().enumerate() {
                let causal = if key <= query { 0.0 } else { NEG_BIG };
                let pad = if key < pad_count { NEG_BIG } else { 0.0 };
                *slot = causal + pad;
            }
        }
        let combined_mask = Array::from_slice(&mask_vals, &[1, 1, t as i32, t as i32]);

        for (i, block) in self.layers.iter_mut().enumerate() {
            hidden = block
                .forward(&hidden, Some(&combined_mask))
                .map_err(map_err)?;
            all_hidden.push(hidden.clone());
            if let Some(dir) = dump_dir.as_deref() {
                dump_hidden(dir, &format!("hidden_{:03}_block{:02}", i + 1, i), &hidden);
            }
        }

        // The final norm is diagnostic only: upstream stacks the UN-normed
        // hidden states, so its output is built just for the parity dump.
        if let Some(dir) = dump_dir.as_deref() {
            let final_hidden = self.final_norm.forward(&hidden).map_err(map_err)?;
            dump_hidden(dir, "hidden_final_norm", &final_hidden);
        }
        // LTX parity hook: mirror ref_ltx.py's `gemma_hidden_final` stage,
        // which is the LAST un-final-normed block output (all_states[-1]).
        if let Some(dir) = ltx_dump_dir.as_deref() {
            dump_hidden(dir, "gemma_hidden_final", &hidden);
        }

        // Stack on a new last axis: [B, T, H] × (L + 1) → [B, T, H, L + 1].
        // Errors are propagated instead of panicking.
        let stacked = {
            let with_new_axis = all_hidden
                .iter()
                .map(|h| {
                    ops::expand_dims(h, -1).or_else(|_| {
                        let s = h.shape();
                        ops::reshape(h, &[s[0], s[1], s[2], 1])
                    })
                })
                .collect::<Result<Vec<Array>, _>>()
                .map_err(map_err)?;
            let refs: Vec<&Array> = with_new_axis.iter().collect();
            ops::concatenate_axis(&refs, -1).map_err(map_err)?
        };

        // Per-token RMS-norm over the hidden axis (axis=2).
        let eps = Array::from_f32(self.config.rms_norm_eps);
        let var = ops::multiply(&stacked, &stacked)
            .map_err(map_err)?
            .mean_axes(&[2], true)
            .map_err(map_err)?;
        let scale = ops::rsqrt(&ops::add(&var, &eps).map_err(map_err)?).map_err(map_err)?;
        let normed = ops::multiply(&stacked, &scale).map_err(map_err)?;

        // Flatten to [B, T, hidden_size * (num_hidden_layers + 1)].
        let s = normed.shape();
        let feat_dim = (self.config.hidden_size * (self.config.num_hidden_layers + 1)) as i32;
        let mut out = ops::reshape(&normed, &[s[0], s[1], feat_dim]).map_err(map_err)?;

        // Upstream `GemmaFeaturesExtractorV2.__call__` (per_token_rms branch)
        // zeros padding positions after the per-token RMS normalization:
        //   mask_3d = attention_mask[:, :, None]; stacked = stacked * mask_3d
        // Otherwise every pad row carries unit-variance features, which
        // inflates the tensor's magnitude and pushes aggregate_embed out of
        // distribution.
        if pad_count > 0 {
            let keep: Vec<f32> = (0..t)
                .map(|pos| if pos < pad_count { 0.0 } else { 1.0 })
                .collect();
            let keep = Array::from_slice(&keep, &[1, t as i32, 1]);
            out = ops::multiply(&out, &keep).map_err(map_err)?;
        }

        // LTX parity hook: mirror ref_ltx.py's `gemma_stacked` stage — the
        // per-token-RMS'd, flattened input to aggregate_embed.
        if let Some(dir) = ltx_dump_dir.as_deref() {
            dump_hidden(dir, "gemma_stacked", &out);
        }
        Ok((out, n_valid))
    }
}