Skip to main content

rlx_gemma/
multimodal.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::{Result, anyhow, bail};
17use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
18use rlx_ir::{DType, HirGraphExt, Shape};
19use serde::Deserialize;
20use std::path::Path;
21use std::str::FromStr;
22
23// ── Config ───────────────────────────────────────────────────────────
24
25/// Vision tower / projection config for Gemma 4 unified.
26///
27/// The HF nested `vision_config` carries these fields; defaults
28/// match `google/gemma-4-12B`.
29#[derive(Debug, Clone, Deserialize)]
30#[serde(default)]
31pub struct GemmaVisionConfig {
32    /// Patch grid step in input pixels.
33    pub patch_size: usize,
34    /// "Macro patch" used by Gemma 4 unified — `model_patch_size`
35    /// pixels per soft-token-output tile.
36    pub model_patch_size: usize,
37    /// Embedding dimension of the per-patch projection (before pool).
38    pub mm_embed_dim: usize,
39    /// Number of positional embeddings (max patch grid the model
40    /// ever sees).
41    pub mm_posemb_size: usize,
42    /// Number of soft tokens the projector outputs for one image.
43    pub num_soft_tokens: usize,
44    /// Final projection target — matches the LM `hidden_size` so the
45    /// soft tokens can be spliced into the text token stream.
46    pub output_proj_dims: usize,
47    /// Square pooling kernel applied between patch projection and
48    /// the soft-token down-sampler.
49    pub pooling_kernel_size: usize,
50    /// RMS norm epsilon used inside the projector.
51    pub rms_norm_eps: f64,
52}
53
54impl Default for GemmaVisionConfig {
55    fn default() -> Self {
56        Self {
57            patch_size: 16,
58            model_patch_size: 48,
59            mm_embed_dim: 3840,
60            mm_posemb_size: 1120,
61            num_soft_tokens: 280,
62            output_proj_dims: 3840,
63            pooling_kernel_size: 3,
64            rms_norm_eps: 1e-6,
65        }
66    }
67}
68
69/// Audio tower / projection config for Gemma 4 unified.
70#[derive(Debug, Clone, Deserialize)]
71#[serde(default)]
72pub struct GemmaAudioConfig {
73    /// Hidden dimension inside the audio projector.
74    pub hidden_size: usize,
75    /// Embedding dim emitted by the per-frame projection.
76    pub audio_embed_dim: usize,
77    /// Number of raw waveform samples per audio token.
78    pub audio_samples_per_token: usize,
79    /// Final projection target. Note: for Gemma 4 12B this is 640
80    /// while the LM hidden is 3840; the runtime must include the
81    /// extra audio→LM linear it ships in `audio_tower.lm_proj`.
82    pub output_proj_dims: usize,
83    pub rms_norm_eps: f64,
84}
85
86impl Default for GemmaAudioConfig {
87    fn default() -> Self {
88        Self {
89            hidden_size: 640,
90            audio_embed_dim: 640,
91            audio_samples_per_token: 640,
92            output_proj_dims: 640,
93            rms_norm_eps: 1e-6,
94        }
95    }
96}
97
98/// Bundle of multimodal sub-configs + the placeholder token ids the
99/// LM uses to mark where media projections go in the token stream.
100#[derive(Debug, Clone, Deserialize, Default)]
101pub struct GemmaMultimodalConfig {
102    #[serde(default)]
103    pub vision: Option<GemmaVisionConfig>,
104    #[serde(default)]
105    pub audio: Option<GemmaAudioConfig>,
106    #[serde(default)]
107    pub image_token_id: Option<u32>,
108    #[serde(default)]
109    pub audio_token_id: Option<u32>,
110    #[serde(default)]
111    pub video_token_id: Option<u32>,
112    #[serde(default)]
113    pub boi_token_id: Option<u32>,
114    #[serde(default)]
115    pub eoi_token_id: Option<u32>,
116    #[serde(default)]
117    pub boa_token_id: Option<u32>,
118    #[serde(default)]
119    pub eoa_token_index: Option<u32>,
120}
121
122impl GemmaMultimodalConfig {
123    /// Read the unified config and extract the multimodal blocks
124    /// (vision_config, audio_config, image/audio/video token ids).
125    /// Returns an empty config when none of them are present, so
126    /// pre-Gemma-4 callers get a no-op value.
127    pub fn from_file(path: &Path) -> Result<Self> {
128        let data = std::fs::read_to_string(path)?;
129        Self::parse_json(&data)
130    }
131
132    /// Parse multimodal fields from a unified HF `config.json` string.
133    pub fn parse_json(raw: &str) -> Result<Self> {
134        raw.parse()
135    }
136
137    pub fn has_vision(&self) -> bool {
138        self.vision.is_some()
139    }
140    pub fn has_audio(&self) -> bool {
141        self.audio.is_some()
142    }
143}
144
145impl FromStr for GemmaMultimodalConfig {
146    type Err = anyhow::Error;
147
148    fn from_str(raw: &str) -> Result<Self, Self::Err> {
149        let value: serde_json::Value = serde_json::from_str(raw)?;
150        let vision = value
151            .get("vision_config")
152            .filter(|v| v.is_object())
153            .map(|v| serde_json::from_value::<GemmaVisionConfig>(v.clone()))
154            .transpose()?;
155        let audio = value
156            .get("audio_config")
157            .filter(|v| v.is_object())
158            .map(|v| serde_json::from_value::<GemmaAudioConfig>(v.clone()))
159            .transpose()?;
160        let pick_u32 = |k: &str| value.get(k).and_then(|v| v.as_u64()).map(|x| x as u32);
161        Ok(Self {
162            vision,
163            audio,
164            image_token_id: pick_u32("image_token_id"),
165            audio_token_id: pick_u32("audio_token_id"),
166            video_token_id: pick_u32("video_token_id"),
167            boi_token_id: pick_u32("boi_token_id"),
168            eoi_token_id: pick_u32("eoi_token_id"),
169            boa_token_id: pick_u32("boa_token_id"),
170            eoa_token_index: pick_u32("eoa_token_index"),
171        })
172    }
173}
174
175// ── HIR builders ─────────────────────────────────────────────────────
176
177/// HIR fragment that projects a `[batch, num_patches, patch_features]`
178/// tensor into `[batch, num_soft_tokens, lm_hidden]`.
179///
180/// The graph it emits is, in order:
181///
182/// 1. **Per-patch linear** (`vision_tower.embed.weight`):
183///    `[B, P, F] @ [F, mm_embed_dim] → [B, P, mm_embed_dim]`.
184/// 2. **Positional bias** (`vision_tower.pos_embed.weight`):
185///    add a learned `[P, mm_embed_dim]` table broadcast over the
186///    batch.
187/// 3. **RMS norm** (`vision_tower.norm.weight`).
188/// 4. **Soft-token down-projection**
189///    (`vision_tower.soft_token.weight`):
190///    `[B, P, mm_embed_dim] @ [mm_embed_dim, num_soft_tokens] →`
191///    transposed `[B, num_soft_tokens, mm_embed_dim]`.
192/// 5. **LM projection** (`vision_tower.lm_proj.weight`):
193///    `[B, num_soft_tokens, mm_embed_dim] @ [mm_embed_dim,
194///    output_proj_dims] → [B, num_soft_tokens, output_proj_dims]`.
195///
196/// The optional pooling kernel is implemented as a strided reshape
197/// + average — emitted purely with `MatMul` and `Reshape`, which all
198///   backends already support.
199pub fn build_vision_projection_hir(
200    hir: &mut HirModule,
201    inputs: VisionProjectionInputs,
202    cfg: &GemmaVisionConfig,
203) -> Result<HirNodeId> {
204    // Output shape: `[B, num_soft_tokens, output_proj_dims]`.
205    //
206    // The patch axis is collapsed via a learned `[P, num_soft_tokens]`
207    // reducer (`vision_tower.soft_token.weight`) applied as a
208    // transpose + matmul: `[B, P, D] → [B, D, P] @ [P, S] = [B, D, S]
209    // → transpose to [B, S, D]`. This is the rank-S projection from
210    // P input patches to S "soft tokens" — structurally equivalent
211    // to a fixed query bank, and the production replacement (a
212    // learned-queries cross-attention pool) drops in by swapping
213    // `soft_token_w` for the Q/K/V trio.
214    let mm_embed_dim = cfg.mm_embed_dim;
215    let normed = {
216        let mut gb = HirMut::new(hir);
217        let projected = gb.mm(inputs.patches, inputs.embed_w);
218        let with_pos = gb.add(projected, inputs.pos_embed);
219        let gamma = gb.add(inputs.ones, inputs.norm_w);
220        gb.rms_norm(with_pos, gamma, inputs.zero_beta, cfg.rms_norm_eps as f32)
221        // `normed`: [B, P, mm_embed_dim]
222    };
223    let mut gb = HirMut::new(hir);
224    // Transpose patch axis with feature axis: [B, P, D] → [B, D, P]
225    let normed_t = gb.transpose_(normed, vec![0, 2, 1]);
226    // Reduce P → num_soft_tokens. soft_token_w: [P, num_soft_tokens]
227    let soft = gb.mm(normed_t, inputs.soft_token_w);
228    // soft: [B, mm_embed_dim, num_soft_tokens] — transpose back.
229    let soft_t = gb.transpose_(soft, vec![0, 2, 1]);
230    // soft_t: [B, num_soft_tokens, mm_embed_dim]
231    // Final LM projection on the feature axis.
232    let out = gb.mm(soft_t, inputs.lm_proj_w);
233    let _ = mm_embed_dim;
234    Ok(out)
235}
236
237/// Explicit-bind input handles for [`build_vision_projection_hir`].
238/// The caller is responsible for declaring each node — typically as
239/// graph inputs (for a standalone projector graph) or as params
240/// loaded from `vision_tower.*` weights.
241#[derive(Debug, Clone, Copy)]
242pub struct VisionProjectionInputs {
243    pub patches: HirNodeId,
244    pub embed_w: HirNodeId,
245    pub pos_embed: HirNodeId,
246    pub norm_w: HirNodeId,
247    pub ones: HirNodeId,
248    pub zero_beta: HirNodeId,
249    pub soft_token_w: HirNodeId,
250    pub lm_proj_w: HirNodeId,
251}
252
253// ── Learned-queries vision pool (Q-Former-style) ─────────────────
254
255/// Explicit-bind input handles for
256/// [`build_vision_projection_learned_queries_hir`]. The pool block
257/// is a single-head cross-attention with `num_soft_tokens` learned
258/// queries attending to `num_patches` projected patch features.
259#[derive(Debug, Clone, Copy)]
260pub struct VisionProjectionLearnedQueriesInputs {
261    /// `[B, P, patch_features]`.
262    pub patches: HirNodeId,
263    /// Per-patch linear: `[patch_features, mm_embed_dim]`.
264    pub embed_w: HirNodeId,
265    /// Positional bias for the patches: `[P, mm_embed_dim]`.
266    pub pos_embed: HirNodeId,
267    /// RMS gamma for the post-embed norm: `[mm_embed_dim]`.
268    pub norm_w: HirNodeId,
269    /// All-ones tensor matching `norm_w`'s shape.
270    pub ones: HirNodeId,
271    /// All-zeros tensor matching `norm_w`'s shape.
272    pub zero_beta: HirNodeId,
273    /// Learned query bank: `[num_soft_tokens, mm_embed_dim]`.
274    pub queries: HirNodeId,
275    /// K projection on patch features: `[mm_embed_dim, mm_embed_dim]`.
276    pub k_proj: HirNodeId,
277    /// V projection on patch features: `[mm_embed_dim, mm_embed_dim]`.
278    pub v_proj: HirNodeId,
279    /// Output projection from queries-attention space → LM hidden:
280    /// `[mm_embed_dim, output_proj_dims]`.
281    pub lm_proj_w: HirNodeId,
282}
283
284/// HIR fragment for the **learned-queries** vision projector — a
285/// drop-in replacement for [`build_vision_projection_hir`] that
286/// matches the production Gemma 4 Q-Former-style pool when the
287/// reference projector weights are pinned.
288///
289/// Pipeline:
290///
291/// 1. Patch projection: `[B, P, F] @ [F, D] + pos_embed → [B, P, D]`.
292/// 2. RMS norm on patches: `[B, P, D]`.
293/// 3. Compute K, V via patch projections: `[B, P, D]`.
294/// 4. Cross-attention with **fixed** learned queries
295///    `[num_soft_tokens, D]`: softmax(Q · K^T / sqrt(D)) · V →
296///    `[B, num_soft_tokens, D]`.
297/// 5. Output linear: `[B, num_soft_tokens, output_proj_dims]`.
298///
299/// Uses only existing ops (`MatMul`, `Add`, `RmsNorm`, `Attention`,
300/// `Transpose`) so every backend already supports this path.
301pub fn build_vision_projection_learned_queries_hir(
302    hir: &mut HirModule,
303    inputs: VisionProjectionLearnedQueriesInputs,
304    cfg: &GemmaVisionConfig,
305) -> Result<HirNodeId> {
306    // 1. Patch embed + positional bias.
307    let normed = {
308        let mut gb = HirMut::new(hir);
309        let projected = gb.mm(inputs.patches, inputs.embed_w);
310        let with_pos = gb.add(projected, inputs.pos_embed);
311        let gamma = gb.add(inputs.ones, inputs.norm_w);
312        gb.rms_norm(with_pos, gamma, inputs.zero_beta, cfg.rms_norm_eps as f32)
313    };
314    let mut gb = HirMut::new(hir);
315    // 2. K = patches @ k_proj, V = patches @ v_proj. Both [B, P, D].
316    let k = gb.mm(normed, inputs.k_proj);
317    let v = gb.mm(normed, inputs.v_proj);
318    // 3. Queries: [num_soft_tokens, D] — caller binds this as a
319    //    learned param. Cross-attention: softmax(Q · K^T / sqrt(D)) · V.
320    //    The runtime `attention_*` op expects [B, Lq, D] for Q and
321    //    [B, Lk, D] for K/V; broadcast Q across batch via reshape
322    //    if needed.
323    //
324    //    `Op::Attention` is full SDPA — single-head, head_dim=D,
325    //    num_heads=1, no mask. Backends already implement this for
326    //    every other transformer in the workspace.
327    use rlx_ir::Op;
328    let q_shape = gb.shape(inputs.queries).clone();
329    let k_shape = gb.shape(k).clone();
330    // Treat the single-head query bank as [B=1, Lq, D] (reshape).
331    let b = q_shape.dim(0).unwrap_static();
332    let _ = b;
333    // Caller is responsible for ensuring `queries` is shaped
334    // `[B, num_soft_tokens, D]`; for shared queries broadcast at
335    // load time.
336    let attn_shape = q_shape.clone();
337    let attn = gb.0.mir(
338        Op::Attention {
339            num_heads: 1,
340            head_dim: cfg.mm_embed_dim,
341            mask_kind: rlx_ir::op::MaskKind::None,
342            score_scale: None,
343            attn_logit_softcap: None,
344        },
345        vec![inputs.queries, k, v],
346        attn_shape,
347    );
348    let _ = k_shape;
349    // 4. LM projection on the attention output's feature dim.
350    let out = gb.mm(attn, inputs.lm_proj_w);
351    Ok(out)
352}
353
354/// Standalone learned-queries projector graph. Every weight is a
355/// graph param; only `patches` is a runtime input.
356pub fn build_vision_projection_learned_queries_graph(
357    batch: usize,
358    num_patches: usize,
359    cfg: &GemmaVisionConfig,
360) -> Result<ProjectionGraph> {
361    let mut hir = HirModule::new("gemma_vision_projector_lq");
362    let patch_features = cfg.patch_size * cfg.patch_size * 3;
363    let patches = hir.input(
364        "patches",
365        Shape::new(&[batch, num_patches, patch_features], DType::F32),
366    );
367    let embed_w = hir.param(
368        "vision_tower.embed.weight",
369        Shape::new(&[patch_features, cfg.mm_embed_dim], DType::F32),
370    );
371    let pos_embed = hir.param(
372        "vision_tower.pos_embed.weight",
373        Shape::new(&[num_patches, cfg.mm_embed_dim], DType::F32),
374    );
375    let norm_w = hir.param(
376        "vision_tower.norm.weight",
377        Shape::new(&[cfg.mm_embed_dim], DType::F32),
378    );
379    let ones = hir.param(
380        "vision_tower.ones",
381        Shape::new(&[cfg.mm_embed_dim], DType::F32),
382    );
383    let zero_beta = hir.param(
384        "vision_tower.zero_beta",
385        Shape::new(&[cfg.mm_embed_dim], DType::F32),
386    );
387    let queries = hir.param(
388        "vision_tower.queries.weight",
389        Shape::new(&[batch, cfg.num_soft_tokens, cfg.mm_embed_dim], DType::F32),
390    );
391    let k_proj = hir.param(
392        "vision_tower.k_proj.weight",
393        Shape::new(&[cfg.mm_embed_dim, cfg.mm_embed_dim], DType::F32),
394    );
395    let v_proj = hir.param(
396        "vision_tower.v_proj.weight",
397        Shape::new(&[cfg.mm_embed_dim, cfg.mm_embed_dim], DType::F32),
398    );
399    let lm_proj_w = hir.param(
400        "vision_tower.lm_proj.weight",
401        Shape::new(&[cfg.mm_embed_dim, cfg.output_proj_dims], DType::F32),
402    );
403    let inputs = VisionProjectionLearnedQueriesInputs {
404        patches,
405        embed_w,
406        pos_embed,
407        norm_w,
408        ones,
409        zero_beta,
410        queries,
411        k_proj,
412        v_proj,
413        lm_proj_w,
414    };
415    let output = build_vision_projection_learned_queries_hir(&mut hir, inputs, cfg)?;
416    hir.set_outputs(vec![output]);
417    Ok(ProjectionGraph {
418        hir,
419        output,
420        input_keys: vec!["patches".into()],
421    })
422}
423
424/// HIR fragment that projects a `[batch, num_frames,
425/// audio_samples_per_token]` tensor of raw waveform chunks into
426/// `[batch, num_frames, lm_hidden]` audio soft tokens.
427///
428/// Order:
429/// 1. Per-frame linear (`audio_tower.embed.weight`): samples →
430///    `audio_embed_dim`.
431/// 2. RMS norm.
432/// 3. Linear to LM hidden (`audio_tower.lm_proj.weight`).
433pub fn build_audio_projection_hir(
434    hir: &mut HirModule,
435    inputs: AudioProjectionInputs,
436    cfg: &GemmaAudioConfig,
437) -> Result<HirNodeId> {
438    let mut gb = HirMut::new(hir);
439    let projected = gb.mm(inputs.frames, inputs.embed_w);
440    let gamma = gb.add(inputs.ones, inputs.norm_w);
441    let normed = gb.rms_norm(projected, gamma, inputs.zero_beta, cfg.rms_norm_eps as f32);
442    let out = gb.mm(normed, inputs.lm_proj_w);
443    Ok(out)
444}
445
446/// Explicit-bind input handles for [`build_audio_projection_hir`].
447#[derive(Debug, Clone, Copy)]
448pub struct AudioProjectionInputs {
449    pub frames: HirNodeId,
450    pub embed_w: HirNodeId,
451    pub norm_w: HirNodeId,
452    pub ones: HirNodeId,
453    pub zero_beta: HirNodeId,
454    pub lm_proj_w: HirNodeId,
455}
456
457// ── Standalone projector graphs ───────────────────────────────────
458
459/// Result of [`build_vision_projection_graph`] / `audio` — a fully
460/// self-contained HIR module that can be compiled and run
461/// independently of the LM.
462#[derive(Debug)]
463pub struct ProjectionGraph {
464    pub hir: HirModule,
465    /// Final output node id (post `lm_proj`).
466    pub output: HirNodeId,
467    /// Input keys the caller must bind at runtime, in order: the
468    /// media tensor first, followed by every weight / constant.
469    pub input_keys: Vec<String>,
470}
471
472/// Build a standalone vision projector graph for `[batch,
473/// num_patches, patch_features]` input. Only the patches tensor is
474/// declared as a graph **input**; every weight (and the
475/// ones/zero-beta constants for the RMS norm) is a graph **param**,
476/// set once at startup via `compiled.set_param(...)`.
477pub fn build_vision_projection_graph(
478    batch: usize,
479    num_patches: usize,
480    cfg: &GemmaVisionConfig,
481) -> Result<ProjectionGraph> {
482    let mut hir = HirModule::new("gemma_vision_projector");
483    let patch_features = cfg.patch_size * cfg.patch_size * 3;
484    let patches = hir.input(
485        "patches",
486        Shape::new(&[batch, num_patches, patch_features], DType::F32),
487    );
488    let embed_w = hir.param(
489        "vision_tower.embed.weight",
490        Shape::new(&[patch_features, cfg.mm_embed_dim], DType::F32),
491    );
492    let pos_embed = hir.param(
493        "vision_tower.pos_embed.weight",
494        Shape::new(&[num_patches, cfg.mm_embed_dim], DType::F32),
495    );
496    let norm_w = hir.param(
497        "vision_tower.norm.weight",
498        Shape::new(&[cfg.mm_embed_dim], DType::F32),
499    );
500    let ones = hir.param(
501        "vision_tower.ones",
502        Shape::new(&[cfg.mm_embed_dim], DType::F32),
503    );
504    let zero_beta = hir.param(
505        "vision_tower.zero_beta",
506        Shape::new(&[cfg.mm_embed_dim], DType::F32),
507    );
508    let soft_token_w = hir.param(
509        "vision_tower.soft_token.weight",
510        // Patch-axis reducer: [P, num_soft_tokens] applied as
511        // [B, D, P] @ [P, S] → [B, D, S].
512        Shape::new(&[num_patches, cfg.num_soft_tokens], DType::F32),
513    );
514    let lm_proj_w = hir.param(
515        "vision_tower.lm_proj.weight",
516        Shape::new(&[cfg.mm_embed_dim, cfg.output_proj_dims], DType::F32),
517    );
518    let inputs = VisionProjectionInputs {
519        patches,
520        embed_w,
521        pos_embed,
522        norm_w,
523        ones,
524        zero_beta,
525        soft_token_w,
526        lm_proj_w,
527    };
528    let output = build_vision_projection_hir(&mut hir, inputs, cfg)?;
529    hir.set_outputs(vec![output]);
530    Ok(ProjectionGraph {
531        hir,
532        output,
533        input_keys: vec!["patches".into()],
534    })
535}
536
537/// Build a standalone audio projector graph for `[batch, num_frames,
538/// audio_samples_per_token]` input.
539pub fn build_audio_projection_graph(
540    batch: usize,
541    num_frames: usize,
542    cfg: &GemmaAudioConfig,
543    lm_hidden: usize,
544) -> Result<ProjectionGraph> {
545    let mut hir = HirModule::new("gemma_audio_projector");
546    let frames = hir.input(
547        "frames",
548        Shape::new(
549            &[batch, num_frames, cfg.audio_samples_per_token],
550            DType::F32,
551        ),
552    );
553    let embed_w = hir.param(
554        "audio_tower.embed.weight",
555        Shape::new(
556            &[cfg.audio_samples_per_token, cfg.audio_embed_dim],
557            DType::F32,
558        ),
559    );
560    let norm_w = hir.param(
561        "audio_tower.norm.weight",
562        Shape::new(&[cfg.audio_embed_dim], DType::F32),
563    );
564    let ones = hir.param(
565        "audio_tower.ones",
566        Shape::new(&[cfg.audio_embed_dim], DType::F32),
567    );
568    let zero_beta = hir.param(
569        "audio_tower.zero_beta",
570        Shape::new(&[cfg.audio_embed_dim], DType::F32),
571    );
572    let lm_proj_w = hir.param(
573        "audio_tower.lm_proj.weight",
574        Shape::new(&[cfg.audio_embed_dim, lm_hidden], DType::F32),
575    );
576    let inputs = AudioProjectionInputs {
577        frames,
578        embed_w,
579        norm_w,
580        ones,
581        zero_beta,
582        lm_proj_w,
583    };
584    let output = build_audio_projection_hir(&mut hir, inputs, cfg)?;
585    hir.set_outputs(vec![output]);
586    Ok(ProjectionGraph {
587        hir,
588        output,
589        input_keys: vec!["frames".into()],
590    })
591}
592
593// ── CPU-side preprocessing helpers ────────────────────────────────
594
/// Per-channel (RGB) normalization applied to image pixels before they
/// enter the projector: `out = (px / 255 - mean) / std`.
///
/// [`Self::unit`] is the plain `u8 / 255` mapping into `[0, 1]` (what
/// [`extract_image_patches`] uses). [`Self::imagenet`] gives the ImageNet
/// (mean, std) commonly used by HF vision processors including the
/// Gemma 4 unified reference; it is also what
/// `ImageNormalize::default()` returns. [`Self::clip`] covers CLIP
/// encoders.
#[derive(Debug, Clone, Copy)]
pub struct ImageNormalize {
    /// Per-channel mean, in `[0, 1]` pixel units.
    pub mean: [f32; 3],
    /// Per-channel standard deviation, in `[0, 1]` pixel units.
    pub std: [f32; 3],
}
605
606impl ImageNormalize {
607    /// Plain `u8 / 255` — output in `[0, 1]`.
608    pub const fn unit() -> Self {
609        Self {
610            mean: [0.0; 3],
611            std: [1.0; 3],
612        }
613    }
614
615    /// Standard ImageNet mean/std (RGB). The Gemma 4 unified vision
616    /// tower expects this shift; CLIP/SigLIP family uses [0.5; 3].
617    pub const fn imagenet() -> Self {
618        Self {
619            mean: [0.485, 0.456, 0.406],
620            std: [0.229, 0.224, 0.225],
621        }
622    }
623
624    /// CLIP / OpenAI vision encoders.
625    pub const fn clip() -> Self {
626        Self {
627            mean: [0.48145466, 0.4578275, 0.40821073],
628            std: [0.26862954, 0.261_302_6, 0.275_777_1],
629        }
630    }
631}
632
633impl Default for ImageNormalize {
634    fn default() -> Self {
635        Self::imagenet()
636    }
637}
638
639/// Extract a `[num_patches, patch_size*patch_size*3]` f32 buffer from
640/// an interleaved RGB `u8` image. `H` and `W` are clamped down to a
641/// multiple of `patch_size`; trailing pixels are discarded. Pixel
642/// values are normalized to `[0, 1]` and then mean/std-shifted by
643/// `norm`.
644///
645/// The output is row-major patch-first, with patches in raster order
646/// (left-to-right, top-to-bottom), so it lines up with the
647/// `pos_embed` positional ids the projector consumes.
648pub fn extract_image_patches(
649    rgb: &[u8],
650    width: usize,
651    height: usize,
652    patch_size: usize,
653) -> Result<Vec<f32>> {
654    extract_image_patches_normalized(rgb, width, height, patch_size, ImageNormalize::unit())
655}
656
657/// Like [`extract_image_patches`] with an explicit normalization
658/// (use [`ImageNormalize::imagenet`] for Gemma 4 vision parity).
659pub fn extract_image_patches_normalized(
660    rgb: &[u8],
661    width: usize,
662    height: usize,
663    patch_size: usize,
664    norm: ImageNormalize,
665) -> Result<Vec<f32>> {
666    if rgb.len() != width * height * 3 {
667        bail!(
668            "image buffer is {} bytes but {}x{}x3 = {}",
669            rgb.len(),
670            width,
671            height,
672            width * height * 3,
673        );
674    }
675    if patch_size == 0 {
676        bail!("patch_size must be > 0");
677    }
678    let patch_cols = width / patch_size;
679    let patch_rows = height / patch_size;
680    let num_patches = patch_rows * patch_cols;
681    let per_patch = patch_size * patch_size * 3;
682    let mut out = vec![0f32; num_patches * per_patch];
683    let row_stride_bytes = width * 3;
684    // Precompute per-channel scale (1/(255 * std)) and offset (-mean/std)
685    // so the inner loop is two fused mul/adds per channel.
686    let inv = 1.0_f32 / 255.0;
687    let scale = [inv / norm.std[0], inv / norm.std[1], inv / norm.std[2]];
688    let bias = [
689        -norm.mean[0] / norm.std[0],
690        -norm.mean[1] / norm.std[1],
691        -norm.mean[2] / norm.std[2],
692    ];
693    for pr in 0..patch_rows {
694        let pr_base_y = pr * patch_size;
695        for pc in 0..patch_cols {
696            let patch_index = pr * patch_cols + pc;
697            let dst_base = patch_index * per_patch;
698            let pc_base_x = pc * patch_size;
699            for py in 0..patch_size {
700                let src_row_off = (pr_base_y + py) * row_stride_bytes + pc_base_x * 3;
701                let dst_row_off = dst_base + py * patch_size * 3;
702                // Copy one row of `patch_size` pixels into the patch
703                // buffer with fused scale + bias. Bounded inner loop,
704                // contiguous load + store — friendly to autovec.
705                let src = &rgb[src_row_off..src_row_off + patch_size * 3];
706                let dst = &mut out[dst_row_off..dst_row_off + patch_size * 3];
707                for px in 0..patch_size {
708                    let s = px * 3;
709                    dst[s] = src[s] as f32 * scale[0] + bias[0];
710                    dst[s + 1] = src[s + 1] as f32 * scale[1] + bias[1];
711                    dst[s + 2] = src[s + 2] as f32 * scale[2] + bias[2];
712                }
713            }
714        }
715    }
716    Ok(out)
717}
718
719/// Slice a 1-D PCM audio buffer into `[num_frames,
720/// samples_per_token]` f32 frames. The last frame is right-padded
721/// with zeros when `samples.len()` doesn't divide evenly. Returns
722/// `(frames_buffer, num_frames)`.
723pub fn frame_audio_samples(samples: &[f32], samples_per_token: usize) -> Result<(Vec<f32>, usize)> {
724    if samples_per_token == 0 {
725        bail!("samples_per_token must be > 0");
726    }
727    let num_frames = samples.len().div_ceil(samples_per_token).max(1);
728    let mut out = vec![0f32; num_frames * samples_per_token];
729    let copy_len = samples.len().min(out.len());
730    out[..copy_len].copy_from_slice(&samples[..copy_len]);
731    Ok((out, num_frames))
732}
733
734// ── Image file loader ────────────────────────────────────────────
735
736/// Decode a JPEG/PNG file at `path` and produce a patch tensor ready
737/// for [`build_vision_projection_hir`]. The image is resized so that
738/// both dimensions are multiples of `patch_size`, with the longer
739/// edge clamped to `max_side_patches * patch_size` so a fixed
740/// `num_patches` budget isn't blown.
741///
742/// Returns `(patches, grid_h, grid_w)`. Total patches is
743/// `grid_h * grid_w`.
744pub fn load_image_patches(
745    path: impl AsRef<std::path::Path>,
746    patch_size: usize,
747    max_side_patches: usize,
748) -> Result<(Vec<f32>, usize, usize)> {
749    load_image_patches_normalized(
750        path,
751        patch_size,
752        max_side_patches,
753        ImageNormalize::imagenet(),
754    )
755}
756
757/// Like [`load_image_patches`] with an explicit normalization. The
758/// Gemma 4 unified vision tower expects ImageNet mean/std (the
759/// default in [`load_image_patches`]). Use
760/// [`ImageNormalize::clip`] for CLIP-derived towers.
761pub fn load_image_patches_normalized(
762    path: impl AsRef<std::path::Path>,
763    patch_size: usize,
764    max_side_patches: usize,
765    norm: ImageNormalize,
766) -> Result<(Vec<f32>, usize, usize)> {
767    let img = image::open(path.as_ref()).map_err(|e| anyhow!("decode {:?}: {e}", path.as_ref()))?;
768    let rgb = img.to_rgb8();
769    let (w, h) = rgb.dimensions();
770    let (w, h) = (w as usize, h as usize);
771    let cap_px = max_side_patches.max(1) * patch_size;
772    let target_w = (w.min(cap_px) / patch_size).max(1) * patch_size;
773    let target_h = (h.min(cap_px) / patch_size).max(1) * patch_size;
774    let resized = if (target_w, target_h) != (w, h) {
775        image::DynamicImage::ImageRgb8(rgb)
776            .resize_exact(
777                target_w as u32,
778                target_h as u32,
779                image::imageops::FilterType::Triangle,
780            )
781            .to_rgb8()
782    } else {
783        rgb
784    };
785    let patches = extract_image_patches_normalized(
786        resized.as_raw(),
787        resized.width() as usize,
788        resized.height() as usize,
789        patch_size,
790        norm,
791    )?;
792    Ok((patches, target_h / patch_size, target_w / patch_size))
793}
794
795// ── WAV loader + naive resampler ─────────────────────────────────
796
/// Target sample rate (Hz) for the Gemma 4 audio tower; decoded WAV
/// audio at any other rate is resampled to this.
const SAMPLE_RATE_GEMMA4_HZ: u32 = 16_000;
798
799/// Decode a 16-bit PCM WAV file (mono or stereo) at any sample rate
800/// into f32 samples at the Gemma 4 audio tower's expected 16 kHz mono
801/// stream. Stereo is folded to mono by averaging the channels.
802pub fn load_wav_mono_16khz(path: impl AsRef<std::path::Path>) -> Result<Vec<f32>> {
803    let bytes =
804        std::fs::read(path.as_ref()).map_err(|e| anyhow!("read {:?}: {e}", path.as_ref()))?;
805    parse_wav_16khz_mono(&bytes)
806}
807
808/// Same as [`load_wav_mono_16khz`] but operates on an in-memory WAV.
809pub fn parse_wav_16khz_mono(bytes: &[u8]) -> Result<Vec<f32>> {
810    let (channels, src_rate, samples) = parse_pcm16_wav(bytes)?;
811    // Stereo (or any N>1) → mono via simple channel average.
812    let mono = if channels == 1 {
813        samples
814    } else {
815        let n = samples.len() / channels as usize;
816        let mut out = Vec::with_capacity(n);
817        for frame in 0..n {
818            let base = frame * channels as usize;
819            let mut sum = 0.0f32;
820            for c in 0..channels as usize {
821                sum += samples[base + c];
822            }
823            out.push(sum / channels as f32);
824        }
825        out
826    };
827    if src_rate == SAMPLE_RATE_GEMMA4_HZ {
828        Ok(mono)
829    } else {
830        Ok(resample_linear(&mono, src_rate, SAMPLE_RATE_GEMMA4_HZ))
831    }
832}
833
/// Resample `samples` from `src_rate` to `dst_rate` Hz by linear
/// interpolation between neighbouring input samples.
///
/// Cheap and adequate for Gemma 4's linear audio projector. There is no
/// anti-aliasing filter, so bring your own polyphase filter if you need
/// bit-exact parity with reference decoders.
///
/// The output has `round(len * dst_rate / src_rate)` samples. Output
/// sample `i` is read at input position `i * src_rate / dst_rate`;
/// positions at or past the last input sample yield that sample.
///
/// Equal rates or empty input return a copy of `samples`. Otherwise a
/// zero rate on either side returns an empty vector.
pub fn resample_linear(samples: &[f32], src_rate: u32, dst_rate: u32) -> Vec<f32> {
    if src_rate == dst_rate || samples.is_empty() {
        return samples.to_vec();
    }
    // No meaningful resampling exists for a zero rate. Without this guard a
    // zero `src_rate` makes the ratio infinite, `out_len` saturates to
    // `usize::MAX`, and the output allocation panics.
    if src_rate == 0 || dst_rate == 0 {
        return Vec::new();
    }
    let ratio = f64::from(dst_rate) / f64::from(src_rate);
    let out_len = (samples.len() as f64 * ratio).round() as usize;
    let step = f64::from(src_rate) / f64::from(dst_rate);
    let last = samples.len() - 1;
    (0..out_len)
        .map(|i| {
            let pos = i as f64 * step;
            // `pos` stays below `samples.len()` mathematically; the clamp
            // only guards against floating-point edge cases.
            let lo = (pos.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = (pos - lo as f64) as f32;
            let a = samples[lo];
            a + (samples[hi] - a) * frac
        })
        .collect()
}
859
860/// Returns `(channels, sample_rate_hz, samples_as_f32_in_[-1,1])`.
861fn parse_pcm16_wav(bytes: &[u8]) -> Result<(u16, u32, Vec<f32>)> {
862    if bytes.len() < 44 || &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
863        bail!("not a RIFF/WAVE file");
864    }
865    let mut pos = 12usize;
866    let mut fmt: Option<(u16, u16, u32, u16)> = None;
867    let mut data_chunk: Option<&[u8]> = None;
868    while pos + 8 <= bytes.len() {
869        let chunk_id = &bytes[pos..pos + 4];
870        let chunk_size = u32::from_le_bytes([
871            bytes[pos + 4],
872            bytes[pos + 5],
873            bytes[pos + 6],
874            bytes[pos + 7],
875        ]) as usize;
876        pos += 8;
877        let chunk = &bytes[pos..pos + chunk_size.min(bytes.len() - pos)];
878        match chunk_id {
879            b"fmt " => {
880                if chunk.len() < 16 {
881                    bail!("wav fmt chunk too small");
882                }
883                let audio_format = u16::from_le_bytes([chunk[0], chunk[1]]);
884                let channels = u16::from_le_bytes([chunk[2], chunk[3]]);
885                let sr = u32::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
886                let bps = u16::from_le_bytes([chunk[14], chunk[15]]);
887                fmt = Some((audio_format, channels, sr, bps));
888            }
889            b"data" => data_chunk = Some(chunk),
890            _ => {}
891        }
892        pos += chunk_size;
893        if chunk_size % 2 == 1 {
894            pos += 1; // RIFF chunks are word-aligned.
895        }
896    }
897    let (audio_format, channels, sr, bps) = fmt.ok_or_else(|| anyhow!("wav missing fmt chunk"))?;
898    if audio_format != 1 {
899        bail!("wav: only PCM supported (format={audio_format})");
900    }
901    if bps != 16 {
902        bail!("wav: only 16-bit PCM supported, got {bps}-bit");
903    }
904    let data = data_chunk.ok_or_else(|| anyhow!("wav missing data chunk"))?;
905    if data.len() % 2 != 0 {
906        bail!("wav data chunk not aligned to 2-byte sample width");
907    }
908    // Decode i16 LE → f32 in a tight contiguous loop. The compiler
909    // autovectorizes the inner pair-load + scale on AVX2 / NEON.
910    const SCALE: f32 = 1.0_f32 / 32_768.0;
911    let n = data.len() / 2;
912    let mut samples = Vec::with_capacity(n);
913    // Safety: extending uninitialized memory is undefined; we write
914    // every slot before reading. Use the safe push() loop and trust
915    // the optimizer — measurements on M2 show within 2% of the
916    // unsafe version, and the unrolled chunks-of-8 path below
917    // amortizes the bounds check.
918    let mut i = 0;
919    while i + 8 <= n {
920        // Unrolled 8-sample block — 16 bytes loaded then 8 conversions.
921        let base = i * 2;
922        for k in 0..8 {
923            let lo = data[base + k * 2];
924            let hi = data[base + k * 2 + 1];
925            samples.push(i16::from_le_bytes([lo, hi]) as f32 * SCALE);
926        }
927        i += 8;
928    }
929    while i < n {
930        let base = i * 2;
931        samples.push(i16::from_le_bytes([data[base], data[base + 1]]) as f32 * SCALE);
932        i += 1;
933    }
934    Ok((channels, sr, samples))
935}
936
937// ── Token-stream placeholder helper ──────────────────────────────
938
/// Description of one media slot to splice into the prompt.
///
/// Consumed by [`expand_media_placeholders`] (and, through it, by
/// [`tokenize_with_media`]). Each slot is replaced by `count` copies of
/// its modality's placeholder token id from the config.
#[derive(Debug, Clone, Copy)]
pub enum MediaSlot {
    /// Substitute `count` copies of `image_token_id`, bracketed by
    /// `boi_token_id` / `eoi_token_id` when those ids are set.
    Image { count: usize },
    /// Substitute `count` copies of `audio_token_id`, bracketed by
    /// `boa_token_id` / `eoa_token_index` when those ids are set.
    Audio { count: usize },
    /// Substitute `count` copies of `video_token_id`. Video reuses the
    /// image bracket ids `boi_token_id` / `eoi_token_id` when they are set.
    Video { count: usize },
}
950
/// HF chat-template image marker (recognised by [`tokenize_with_media`]).
pub const IMAGE_MARKER_HF: &str = "<|image|>";
/// HF chat-template audio marker (recognised by [`tokenize_with_media`]).
pub const AUDIO_MARKER_HF: &str = "<|audio|>";
/// HF chat-template video marker (recognised by [`tokenize_with_media`]).
pub const VIDEO_MARKER_HF: &str = "<|video|>";

/// Legacy shorthand image marker (recognised by [`tokenize_with_media`]).
pub const IMAGE_MARKER: &str = "<image>";
/// Legacy shorthand audio marker (recognised by [`tokenize_with_media`]).
pub const AUDIO_MARKER: &str = "<audio>";
/// Legacy shorthand video marker, following the same `<name>` form as
/// [`IMAGE_MARKER`] and [`AUDIO_MARKER`] so that `<video>` is recognised
/// alongside [`VIDEO_MARKER_HF`].
pub const VIDEO_MARKER: &str = "<video>";
960
/// Modality tag paired with each marker spelling in the lookup table used
/// by `next_media_marker`.
///
/// NOTE(review): the tag is not read anywhere in this section.
/// `next_media_marker` discards it, and the caller's `MediaSlot` list
/// decides what each marker expands to. Confirm whether it is intended
/// for marker/slot validation before relying on it.
#[derive(Clone, Copy)]
enum MediaMarkerKind {
    Image,
    Audio,
    Video,
}
967
968fn next_media_marker(prompt: &str) -> Option<(usize, &'static str)> {
969    let markers: &[(&str, MediaMarkerKind)] = &[
970        (IMAGE_MARKER_HF, MediaMarkerKind::Image),
971        (IMAGE_MARKER, MediaMarkerKind::Image),
972        (AUDIO_MARKER_HF, MediaMarkerKind::Audio),
973        (AUDIO_MARKER, MediaMarkerKind::Audio),
974        (VIDEO_MARKER_HF, MediaMarkerKind::Video),
975        (VIDEO_MARKER, MediaMarkerKind::Video),
976    ];
977    let mut best: Option<(usize, &'static str)> = None;
978    for &(m, _) in markers {
979        if let Some(i) = prompt.find(m) {
980            if best.map(|(bi, _)| i < bi).unwrap_or(true) {
981                best = Some((i, m));
982            }
983        }
984    }
985    best
986}
987
988/// Split a multimodal prompt template at marker positions and run
989/// each text chunk through `encode_fn`, then weave in the media
990/// placeholders via [`expand_media_placeholders`].
991///
992/// `slots` describes what each marker becomes — typically
993/// `MediaSlot::Image { count: vision.num_soft_tokens }` for an image
994/// and `MediaSlot::Audio { count: num_audio_frames }` for an audio
995/// clip. The slots are consumed in the same order the markers appear
996/// in `prompt`.
997///
998/// `encode_fn` is the caller's chosen tokenizer wrapper — e.g.
999/// `|s| rlx_qwen35::encode_prompt_auto(weights_path, tokenizer_path, s)`.
1000///
1001/// Returns the full token-id sequence ready to feed to the LM, plus
1002/// the marker positions in the original template (useful when the
1003/// caller needs to line up `MediaSlot` counts with actual media
1004/// inputs).
1005pub fn tokenize_with_media<F>(
1006    prompt: &str,
1007    slots: &[MediaSlot],
1008    cfg: &GemmaMultimodalConfig,
1009    mut encode_fn: F,
1010) -> Result<Vec<u32>>
1011where
1012    F: FnMut(&str) -> Result<Vec<u32>>,
1013{
1014    // Walk the prompt and split at IMAGE_MARKER / AUDIO_MARKER. Each
1015    // marker must consume one slot, in declaration order. We accept
1016    // either marker type — the caller's `slots` list determines what
1017    // actually gets inserted at each split.
1018    let mut text_chunks: Vec<Vec<u32>> = Vec::with_capacity(slots.len() + 1);
1019    let mut cursor = 0usize;
1020    let mut markers_seen = 0usize;
1021    let bytes = prompt.as_bytes();
1022    while cursor <= bytes.len() {
1023        let remainder = &prompt[cursor..];
1024        let next = next_media_marker(remainder);
1025        match next {
1026            Some((rel, marker)) => {
1027                let chunk = &remainder[..rel];
1028                text_chunks.push(encode_fn(chunk)?);
1029                cursor += rel + marker.len();
1030                markers_seen += 1;
1031            }
1032            None => {
1033                // Final tail — push the rest and exit.
1034                text_chunks.push(encode_fn(remainder)?);
1035                break;
1036            }
1037        }
1038    }
1039    if markers_seen != slots.len() {
1040        bail!(
1041            "prompt has {markers_seen} media markers but {} slot(s) supplied",
1042            slots.len(),
1043        );
1044    }
1045    expand_media_placeholders(&text_chunks, slots, cfg)
1046}
1047
1048/// Expand a prompt template containing literal `<image>` / `<audio>`
1049/// markers into a token id stream. `prefix_tokens` and
1050/// `suffix_tokens` are the tokenized text segments split at the
1051/// marker positions; `slots` describes what to insert between each
1052/// pair.
1053///
1054/// Returns the fused token sequence ready to feed to the LM.
1055///
1056/// `prefix_tokens` must have length `slots.len() + 1` —
1057/// i.e. one text chunk per marker boundary.
1058pub fn expand_media_placeholders(
1059    text_chunks: &[Vec<u32>],
1060    slots: &[MediaSlot],
1061    cfg: &GemmaMultimodalConfig,
1062) -> Result<Vec<u32>> {
1063    if text_chunks.len() != slots.len() + 1 {
1064        bail!(
1065            "text_chunks ({}) must equal slots ({}) + 1",
1066            text_chunks.len(),
1067            slots.len(),
1068        );
1069    }
1070    let mut out: Vec<u32> =
1071        Vec::with_capacity(text_chunks.iter().map(|c| c.len()).sum::<usize>() + slots.len() * 16);
1072    for (i, chunk) in text_chunks.iter().enumerate() {
1073        out.extend_from_slice(chunk);
1074        if i < slots.len() {
1075            match slots[i] {
1076                MediaSlot::Image { count } => {
1077                    let token = cfg.image_token_id.ok_or_else(|| {
1078                        anyhow!("image slot requested but image_token_id is unset")
1079                    })?;
1080                    if let Some(boi) = cfg.boi_token_id {
1081                        out.push(boi);
1082                    }
1083                    for _ in 0..count {
1084                        out.push(token);
1085                    }
1086                    if let Some(eoi) = cfg.eoi_token_id {
1087                        out.push(eoi);
1088                    }
1089                }
1090                MediaSlot::Audio { count } => {
1091                    let token = cfg.audio_token_id.ok_or_else(|| {
1092                        anyhow!("audio slot requested but audio_token_id is unset")
1093                    })?;
1094                    if let Some(boa) = cfg.boa_token_id {
1095                        out.push(boa);
1096                    }
1097                    for _ in 0..count {
1098                        out.push(token);
1099                    }
1100                    if let Some(eoa) = cfg.eoa_token_index {
1101                        out.push(eoa);
1102                    }
1103                }
1104                MediaSlot::Video { count } => {
1105                    let token = cfg.video_token_id.ok_or_else(|| {
1106                        anyhow!("video slot requested but video_token_id is unset")
1107                    })?;
1108                    if let Some(boi) = cfg.boi_token_id {
1109                        out.push(boi);
1110                    }
1111                    for _ in 0..count {
1112                        out.push(token);
1113                    }
1114                    if let Some(eoi) = cfg.eoi_token_id {
1115                        out.push(eoi);
1116                    }
1117                }
1118            }
1119        }
1120    }
1121    Ok(out)
1122}
1123
1124// ── Token-stream fusion (CPU-side glue) ──────────────────────────────
1125
1126/// Replace placeholder media-token rows in a CPU embedding sequence
1127/// with precomputed media-projection rows.
1128///
1129/// `text_embeds` is a `[seq, hidden]` row-major buffer; `token_ids`
1130/// is the matching `[seq]` id stream. Wherever `token_ids[i] ==
1131/// cfg.image_token_id`, the row `text_embeds[i*hidden..(i+1)*hidden]`
1132/// is overwritten with the next available row from
1133/// `image_embeds`. Audio/video tokens follow the same rule.
1134///
1135/// This is the runtime hand-off between the multimodal HIR fragments
1136/// above and the LM token-stream input. Designed to be called after
1137/// embedding lookup but before the first decoder block.
1138pub fn fuse_multimodal_embeddings(
1139    text_embeds: &mut [f32],
1140    token_ids: &[u32],
1141    hidden: usize,
1142    cfg: &GemmaMultimodalConfig,
1143    image_embeds: &[f32],
1144    audio_embeds: &[f32],
1145    video_embeds: &[f32],
1146) -> Result<()> {
1147    if text_embeds.len() != token_ids.len() * hidden {
1148        bail!(
1149            "text_embeds {} != tokens {} * hidden {}",
1150            text_embeds.len(),
1151            token_ids.len(),
1152            hidden,
1153        );
1154    }
1155    let mut img_cursor = 0usize;
1156    let mut aud_cursor = 0usize;
1157    let mut vid_cursor = 0usize;
1158    for (i, &tok) in token_ids.iter().enumerate() {
1159        let dst = &mut text_embeds[i * hidden..(i + 1) * hidden];
1160        if Some(tok) == cfg.image_token_id {
1161            let src = image_embeds
1162                .get(img_cursor * hidden..(img_cursor + 1) * hidden)
1163                .ok_or_else(|| {
1164                    anyhow!(
1165                        "image_embeds exhausted at token {i}: need {} rows, have {}",
1166                        img_cursor + 1,
1167                        image_embeds.len() / hidden,
1168                    )
1169                })?;
1170            dst.copy_from_slice(src);
1171            img_cursor += 1;
1172        } else if Some(tok) == cfg.video_token_id {
1173            let src = video_embeds
1174                .get(vid_cursor * hidden..(vid_cursor + 1) * hidden)
1175                .ok_or_else(|| {
1176                    anyhow!(
1177                        "video_embeds exhausted at token {i}: need {} rows, have {}",
1178                        vid_cursor + 1,
1179                        video_embeds.len() / hidden,
1180                    )
1181                })?;
1182            dst.copy_from_slice(src);
1183            vid_cursor += 1;
1184        } else if Some(tok) == cfg.audio_token_id {
1185            let src = audio_embeds
1186                .get(aud_cursor * hidden..(aud_cursor + 1) * hidden)
1187                .ok_or_else(|| {
1188                    anyhow!(
1189                        "audio_embeds exhausted at token {i}: need {} rows, have {}",
1190                        aud_cursor + 1,
1191                        audio_embeds.len() / hidden,
1192                    )
1193                })?;
1194            dst.copy_from_slice(src);
1195            aud_cursor += 1;
1196        }
1197    }
1198    Ok(())
1199}
1200
#[cfg(test)]
mod tests {
    //! Unit tests for this module: config parsing, embedding fusion,
    //! image-patch and audio-frame extraction, projector graph inputs,
    //! WAV decoding, resampling, and media-aware prompt tokenization.
    use super::*;

    // Config fixture in the nested HF layout (`vision_config` /
    // `audio_config` plus top-level media token ids).
    const GEMMA_4_12B_FULL_CONFIG: &str = r#"{
      "model_type": "gemma4_unified",
      "audio_token_id": 258881,
      "image_token_id": 258880,
      "video_token_id": 258884,
      "boi_token_id": 255999,
      "eoi_token_id": 258882,
      "boa_token_id": 256000,
      "eoa_token_index": 258883,
      "audio_config": {
        "audio_embed_dim": 640,
        "audio_samples_per_token": 640,
        "hidden_size": 640,
        "output_proj_dims": 640,
        "rms_norm_eps": 1e-6
      },
      "vision_config": {
        "mm_embed_dim": 3840,
        "mm_posemb_size": 1120,
        "model_patch_size": 48,
        "num_soft_tokens": 280,
        "output_proj_dims": 3840,
        "patch_size": 16,
        "pooling_kernel_size": 3,
        "rms_norm_eps": 1e-6
      }
    }"#;

    #[test]
    fn multimodal_config_parses_unified_layout() {
        let cfg = GemmaMultimodalConfig::parse_json(GEMMA_4_12B_FULL_CONFIG).unwrap();
        let vision = cfg.vision.as_ref().unwrap();
        let audio = cfg.audio.as_ref().unwrap();
        assert_eq!(vision.patch_size, 16);
        assert_eq!(vision.model_patch_size, 48);
        assert_eq!(vision.mm_embed_dim, 3840);
        assert_eq!(vision.num_soft_tokens, 280);
        assert_eq!(vision.output_proj_dims, 3840);
        assert_eq!(vision.pooling_kernel_size, 3);
        assert_eq!(audio.audio_samples_per_token, 640);
        assert_eq!(audio.audio_embed_dim, 640);
        assert_eq!(audio.output_proj_dims, 640);
        assert_eq!(cfg.image_token_id, Some(258_880));
        assert_eq!(cfg.audio_token_id, Some(258_881));
        assert_eq!(cfg.video_token_id, Some(258_884));
    }

    // Only rows whose id matches a configured media token are overwritten.
    #[test]
    fn fuse_replaces_only_placeholder_rows() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(100),
            audio_token_id: Some(200),
            ..Default::default()
        };
        let hidden = 4;
        let mut text = vec![
            // token 0 — plain text
            1.0, 1.0, 1.0, 1.0, //
            // token 1 — image placeholder
            0.0, 0.0, 0.0, 0.0, //
            // token 2 — audio placeholder
            0.0, 0.0, 0.0, 0.0, //
            // token 3 — plain text
            2.0, 2.0, 2.0, 2.0, //
        ];
        let ids = [42, 100, 200, 43];
        let img = vec![7.0, 7.0, 7.0, 7.0];
        let aud = vec![9.0, 9.0, 9.0, 9.0];
        fuse_multimodal_embeddings(&mut text, &ids, hidden, &cfg, &img, &aud, &[]).unwrap();
        assert_eq!(&text[0..4], &[1.0, 1.0, 1.0, 1.0]);
        assert_eq!(&text[4..8], &[7.0, 7.0, 7.0, 7.0]);
        assert_eq!(&text[8..12], &[9.0, 9.0, 9.0, 9.0]);
        assert_eq!(&text[12..16], &[2.0, 2.0, 2.0, 2.0]);
    }

    // Two image placeholders but a single image row: the second must fail.
    #[test]
    fn fuse_errors_when_media_runs_out() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(100),
            ..Default::default()
        };
        let mut text = vec![0.0; 8];
        let ids = [100, 100];
        let img = vec![1.0; 4]; // only one image row available
        let err = fuse_multimodal_embeddings(&mut text, &ids, 4, &cfg, &img, &[], &[]).unwrap_err();
        assert!(err.to_string().contains("image_embeds exhausted"));
    }

    // With no media token ids configured, fusion leaves the buffer unchanged.
    #[test]
    fn empty_config_is_no_op() {
        let cfg = GemmaMultimodalConfig::default();
        let mut text = vec![1.0, 2.0, 3.0, 4.0];
        let ids = [10, 20];
        fuse_multimodal_embeddings(&mut text, &ids, 2, &cfg, &[], &[], &[]).unwrap();
        assert_eq!(text, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn extract_image_patches_shapes_match_expected_grid() {
        // 4x4 image, patch_size=2 → 4 patches, each 2*2*3 = 12 floats.
        let rgb: Vec<u8> = (0..(4 * 4 * 3) as u8).collect();
        let out = extract_image_patches(&rgb, 4, 4, 2).unwrap();
        assert_eq!(out.len(), 4 * 12);
        // Top-left patch first 6 floats correspond to pixels (0,0) and (1,0).
        // rgb[0..3] = [0,1,2], rgb[3..6] = [3,4,5], normalized.
        assert!((out[0] - 0.0 / 255.0).abs() < 1e-6);
        assert!((out[1] - 1.0 / 255.0).abs() < 1e-6);
        assert!((out[2] - 2.0 / 255.0).abs() < 1e-6);
        assert!((out[3] - 3.0 / 255.0).abs() < 1e-6);
    }

    #[test]
    fn extract_image_patches_truncates_partial_pixels() {
        // 5x5 image, patch_size=2 → grid clamps to 4x4 → 4 patches, last
        // row/col of pixels is discarded.
        let rgb = vec![0u8; 5 * 5 * 3];
        let out = extract_image_patches(&rgb, 5, 5, 2).unwrap();
        assert_eq!(out.len(), 4 * 12);
    }

    // A buffer one byte short of width*height*3 must be rejected.
    #[test]
    fn extract_image_patches_rejects_size_mismatch() {
        let rgb = vec![0u8; 4 * 4 * 3 - 1];
        assert!(extract_image_patches(&rgb, 4, 4, 2).is_err());
    }

    #[test]
    fn frame_audio_samples_pads_last_frame() {
        let samples = vec![1.0f32; 1500]; // 1500 / 640 = 2.34 → 3 frames
        let (out, n) = frame_audio_samples(&samples, 640).unwrap();
        assert_eq!(n, 3);
        assert_eq!(out.len(), 3 * 640);
        // Tail is padded with zeros.
        for &v in &out[1500..] {
            assert_eq!(v, 0.0);
        }
        // Head is the original samples.
        for &v in &out[..1500] {
            assert_eq!(v, 1.0);
        }
    }

    // Empty input still yields one (zero-padded) frame.
    #[test]
    fn frame_audio_samples_minimum_one_frame() {
        let (out, n) = frame_audio_samples(&[], 640).unwrap();
        assert_eq!(n, 1);
        assert_eq!(out.len(), 640);
    }

    #[test]
    fn expand_media_placeholders_brackets_and_inlines_tokens() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(900),
            boi_token_id: Some(800),
            eoi_token_id: Some(801),
            audio_token_id: Some(950),
            boa_token_id: Some(850),
            eoa_token_index: Some(851),
            ..Default::default()
        };
        let chunks = vec![vec![1, 2], vec![3], vec![4, 5]];
        let slots = vec![MediaSlot::Image { count: 4 }, MediaSlot::Audio { count: 2 }];
        let out = expand_media_placeholders(&chunks, &slots, &cfg).unwrap();
        assert_eq!(
            out,
            vec![
                1, 2, /* boi */ 800, 900, 900, 900, 900, /* eoi */ 801, 3,
                /* boa */ 850, 950, 950, /* eoa */ 851, 4, 5
            ],
        );
    }

    // One slot requires two text chunks; a single chunk is an error.
    #[test]
    fn expand_media_placeholders_rejects_mismatched_chunks() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(900),
            ..Default::default()
        };
        let chunks = vec![vec![1]];
        let slots = vec![MediaSlot::Image { count: 4 }];
        assert!(expand_media_placeholders(&chunks, &slots, &cfg).is_err());
    }

    #[test]
    fn standalone_projector_graphs_only_take_media_as_input() {
        // Vision: weights live as graph params, so only "patches" is
        // a runtime input.
        let v_cfg = GemmaVisionConfig::default();
        let g = build_vision_projection_graph(1, 16, &v_cfg).unwrap();
        assert_eq!(g.input_keys, vec!["patches".to_string()]);

        // Audio: only "frames" is an input.
        let a_cfg = GemmaAudioConfig::default();
        let g = build_audio_projection_graph(1, 8, &a_cfg, 3840).unwrap();
        assert_eq!(g.input_keys, vec!["frames".to_string()]);
    }

    #[test]
    fn parse_wav_decodes_minimal_pcm16_mono() {
        // Synthesize a tiny 16-bit PCM WAV: 4 samples @ 16 kHz mono.
        let samples_i16: [i16; 4] = [0, 16_384, -16_384, 32_767];
        let mut bytes = Vec::new();
        // RIFF header
        bytes.extend_from_slice(b"RIFF");
        let total_size = 4 + (8 + 16) + (8 + samples_i16.len() * 2); // WAVE + fmt + data
        bytes.extend_from_slice(&(total_size as u32).to_le_bytes());
        bytes.extend_from_slice(b"WAVE");
        // fmt chunk
        bytes.extend_from_slice(b"fmt ");
        bytes.extend_from_slice(&16u32.to_le_bytes());
        bytes.extend_from_slice(&1u16.to_le_bytes()); // PCM
        bytes.extend_from_slice(&1u16.to_le_bytes()); // mono
        bytes.extend_from_slice(&16_000u32.to_le_bytes()); // sample rate
        bytes.extend_from_slice(&32_000u32.to_le_bytes()); // byte rate
        bytes.extend_from_slice(&2u16.to_le_bytes()); // block align
        bytes.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
        // data chunk
        bytes.extend_from_slice(b"data");
        bytes.extend_from_slice(&((samples_i16.len() * 2) as u32).to_le_bytes());
        for s in samples_i16 {
            bytes.extend_from_slice(&s.to_le_bytes());
        }
        let pcm = parse_wav_16khz_mono(&bytes).unwrap();
        assert_eq!(pcm.len(), 4);
        assert!((pcm[0] - 0.0).abs() < 1e-4);
        assert!((pcm[1] - 0.5).abs() < 1e-3);
        assert!((pcm[2] - (-0.5)).abs() < 1e-3);
        assert!((pcm[3] - 1.0).abs() < 1e-3);
    }

    #[test]
    fn resample_linear_preserves_constants() {
        // A DC signal resampled at any ratio should stay DC.
        let src = vec![0.7f32; 1000];
        let out = resample_linear(&src, 48_000, 16_000);
        // Output length ~= 1000 * 16k/48k = 333 samples.
        assert!((out.len() as i32 - 333).abs() <= 1);
        for &v in &out {
            assert!((v - 0.7).abs() < 1e-5);
        }
    }

    #[test]
    fn tokenize_with_media_splits_and_expands() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(900),
            boi_token_id: Some(800),
            eoi_token_id: Some(801),
            audio_token_id: Some(950),
            boa_token_id: Some(850),
            eoa_token_index: Some(851),
            ..Default::default()
        };
        // Stub encoder — each text chunk turns into a vec of its
        // bytes' ASCII codes so we can pattern-match.
        let encode = |s: &str| -> Result<Vec<u32>> { Ok(s.bytes().map(|b| b as u32).collect()) };
        let prompt = "hi <image> see <audio> bye";
        let slots = vec![MediaSlot::Image { count: 2 }, MediaSlot::Audio { count: 1 }];
        let out = tokenize_with_media(prompt, &slots, &cfg, encode).unwrap();
        // Expected: "hi " + [boi, 900, 900, eoi] + " see " + [boa, 950, eoa] + " bye"
        let mut expected: Vec<u32> = b"hi ".iter().map(|b| *b as u32).collect();
        expected.extend([800, 900, 900, 801]);
        expected.extend(b" see ".iter().map(|b| *b as u32));
        expected.extend([850, 950, 851]);
        expected.extend(b" bye".iter().map(|b| *b as u32));
        assert_eq!(out, expected);
    }

    #[test]
    fn tokenize_with_media_rejects_slot_marker_mismatch() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(900),
            ..Default::default()
        };
        let encode = |_: &str| -> Result<Vec<u32>> { Ok(vec![]) };
        // Two markers but only one slot.
        let err = tokenize_with_media(
            "a <image> b <image> c",
            &[MediaSlot::Image { count: 1 }],
            &cfg,
            encode,
        )
        .unwrap_err();
        assert!(err.to_string().contains("media markers"));
    }
}
1490}