Skip to main content

rlx_gemma/
multimodal_runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16use anyhow::{Context, Result, anyhow, bail};
17use rlx_runtime::{CompiledGraph, Device, Session};
18use safetensors::SafeTensors;
19use std::collections::HashMap;
20use std::path::{Path, PathBuf};
21
22use crate::multimodal::{
23    AUDIO_MARKER, AUDIO_MARKER_HF, GemmaAudioConfig, GemmaMultimodalConfig, GemmaVisionConfig,
24    IMAGE_MARKER, IMAGE_MARKER_HF, MediaSlot, VIDEO_MARKER, VIDEO_MARKER_HF,
25    build_audio_projection_graph, build_vision_projection_graph, frame_audio_samples,
26    fuse_multimodal_embeddings, load_image_patches, tokenize_with_media,
27};
28use crate::unified_preprocess::{
29    compute_num_soft_tokens_from_size, factorized_pos_bias, load_unified_image,
30    prepare_unified_audio_samples, strip_valid_vision_rows, unified_audio_token_count,
31};
32use crate::unified_projector::{
33    build_unified_audio_graph, build_unified_vision_graph, is_unified_vision_weights,
34};
35
/// Identifies which projector checkpoint layout the loaded weights follow.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ProjectorLayout {
    /// llama.cpp-style `vision_tower.*` tensors plus a learned soft-token pool.
    /// This is the default.
    #[default]
    LegacyPool,
    /// Native `google/gemma-4-12B-it` unified embedders.
    Unified,
}
45
46/// One compiled projector — wraps a `CompiledGraph` + the static
47/// dimensions it was specialised for.
48struct VisionStage {
49    compiled: CompiledGraph,
50    num_patches: usize,
51    output_dim: usize,
52    unified: bool,
53}
54
55impl VisionStage {
56    fn output_shape(&self) -> (usize, usize) {
57        (self.num_patches, self.output_dim)
58    }
59}
60
61struct AudioStage {
62    compiled: CompiledGraph,
63    num_frames: usize,
64    output_dim: usize,
65    unified: bool,
66}
67
68impl AudioStage {
69    fn output_shape(&self) -> (usize, usize) {
70        (self.num_frames, self.output_dim)
71    }
72}
73
74/// End-to-end multimodal pipeline for Gemma 4 unified models.
75///
76/// Holds the compiled vision/audio projector graphs (constructed
77/// once) and runs them on demand to produce hidden-aligned soft
78/// tokens. The LM-side fusion is exposed as
79/// [`GemmaMultimodalRunner::fuse_text_and_media`].
80pub struct GemmaMultimodalRunner {
81    cfg: GemmaMultimodalConfig,
82    lm_hidden: usize,
83    device: Device,
84    layout: ProjectorLayout,
85    max_soft_tokens: usize,
86    max_audio_tokens: usize,
87    vision: Option<VisionStage>,
88    audio: Option<AudioStage>,
89}
90
91impl GemmaMultimodalRunner {
92    /// Construct a runner for a fixed image-grid and audio-frame
93    /// budget. Both `num_patches` and `num_frames` are graph-time
94    /// constants — call [`Self::reconfigure`] if you need different
95    /// dimensions later.
96    pub fn new(
97        cfg: GemmaMultimodalConfig,
98        lm_hidden: usize,
99        device: Device,
100        max_soft_tokens: Option<usize>,
101        max_audio_tokens: Option<usize>,
102    ) -> Result<Self> {
103        let max_soft = max_soft_tokens.unwrap_or_else(|| {
104            cfg.vision
105                .as_ref()
106                .map(|v| v.num_soft_tokens)
107                .unwrap_or(280)
108        });
109        Ok(Self {
110            cfg,
111            lm_hidden,
112            device,
113            layout: ProjectorLayout::LegacyPool,
114            max_soft_tokens: max_soft,
115            max_audio_tokens: max_audio_tokens.unwrap_or(750),
116            vision: None,
117            audio: None,
118        })
119    }
120
121    /// Current vision projector dimensions, when one is compiled.
122    /// `(num_patches, output_dim)` — output_dim equals the LM hidden
123    /// size for the Gemma 4 12B reference.
124    pub fn vision_output_shape(&self) -> Option<(usize, usize)> {
125        self.vision.as_ref().map(|s| s.output_shape())
126    }
127
128    /// Current audio projector dimensions, when one is compiled.
129    /// `(num_frames, output_dim)`.
130    pub fn audio_output_shape(&self) -> Option<(usize, usize)> {
131        self.audio.as_ref().map(|s| s.output_shape())
132    }
133
134    pub fn config(&self) -> &GemmaMultimodalConfig {
135        &self.cfg
136    }
137
138    pub fn lm_hidden(&self) -> usize {
139        self.lm_hidden
140    }
141
142    pub fn layout(&self) -> ProjectorLayout {
143        self.layout
144    }
145
146    pub fn is_unified(&self) -> bool {
147        self.layout == ProjectorLayout::Unified
148    }
149
150    fn sync_layout(&mut self, weights: &MultimodalWeights) {
151        self.layout = weights.layout();
152    }
153
154    fn vision_cfg(&self) -> Result<&GemmaVisionConfig> {
155        self.cfg
156            .vision
157            .as_ref()
158            .ok_or_else(|| anyhow!("vision config missing — model is text/audio only"))
159    }
160
161    fn audio_cfg(&self) -> Result<&GemmaAudioConfig> {
162        self.cfg
163            .audio
164            .as_ref()
165            .ok_or_else(|| anyhow!("audio config missing — model is text/vision only"))
166    }
167
168    /// Compile (or recompile) the vision projector graph for a
169    /// `num_patches` budget. Weights are NOT loaded yet — call
170    /// [`Self::load_vision_weights`] before running.
171    pub fn compile_vision(&mut self, num_patches: usize) -> Result<()> {
172        let vcfg = self.vision_cfg()?.clone();
173        let unified = self.layout == ProjectorLayout::Unified;
174        let g = if unified {
175            build_unified_vision_graph(num_patches, &vcfg)?
176        } else {
177            build_vision_projection_graph(1, num_patches, &vcfg)?
178        };
179        let session = Session::new(self.device);
180        let compiled = session
181            .compile_hir(g.hir)
182            .map_err(|e| anyhow!("vision projector lower failed: {e:?}"))?;
183        self.vision = Some(VisionStage {
184            compiled,
185            num_patches,
186            output_dim: if unified {
187                self.lm_hidden
188            } else {
189                vcfg.output_proj_dims
190            },
191            unified,
192        });
193        Ok(())
194    }
195
196    pub fn compile_audio(&mut self, num_frames: usize) -> Result<()> {
197        let acfg = self.audio_cfg()?.clone();
198        let unified = self.layout == ProjectorLayout::Unified;
199        let g = if unified {
200            build_unified_audio_graph(num_frames, &acfg, self.lm_hidden)?
201        } else {
202            build_audio_projection_graph(1, num_frames, &acfg, self.lm_hidden)?
203        };
204        let session = Session::new(self.device);
205        let compiled = session
206            .compile_hir(g.hir)
207            .map_err(|e| anyhow!("audio projector lower failed: {e:?}"))?;
208        self.audio = Some(AudioStage {
209            compiled,
210            num_frames,
211            output_dim: self.lm_hidden,
212            unified,
213        });
214        Ok(())
215    }
216
217    /// Reconfigure the vision projector for a new patch count. Calls
218    /// [`Self::compile_vision`] followed by
219    /// [`Self::load_vision_weights`].
220    pub fn reconfigure(
221        &mut self,
222        num_patches: Option<usize>,
223        num_frames: Option<usize>,
224        weights: &MultimodalWeights,
225    ) -> Result<()> {
226        if let Some(n) = num_patches {
227            self.compile_vision(n)?;
228            self.load_vision_weights(weights)?;
229        }
230        if let Some(n) = num_frames {
231            self.compile_audio(n)?;
232            self.load_audio_weights(weights)?;
233        }
234        Ok(())
235    }
236
237    /// Bind vision-projector weights into the compiled graph.
238    pub fn load_vision_weights(&mut self, weights: &MultimodalWeights) -> Result<()> {
239        self.sync_layout(weights);
240        let stage = self
241            .vision
242            .as_mut()
243            .ok_or_else(|| anyhow!("vision projector not compiled — call compile_vision first"))?;
244        if stage.unified {
245            let vcfg = self.cfg.vision.as_ref().unwrap();
246            let d = vcfg.mm_embed_dim;
247            for key in [
248                "model.vision_embedder.patch_ln1.weight",
249                "model.vision_embedder.patch_ln1.bias",
250                "model.vision_embedder.patch_dense.weight",
251                "model.vision_embedder.patch_dense.bias",
252                "model.vision_embedder.patch_ln2.weight",
253                "model.vision_embedder.patch_ln2.bias",
254                "model.vision_embedder.pos_norm.weight",
255                "model.vision_embedder.pos_norm.bias",
256                "model.embed_vision.embedding_projection.weight",
257            ] {
258                let data = weights.get_linear(key)?;
259                stage.compiled.set_param(key, &data);
260            }
261            stage.compiled.set_param("unified.ones", &vec![1.0f32; d]);
262            stage
263                .compiled
264                .set_param("unified.zero_beta", &vec![0.0f32; d]);
265        } else {
266            for key in [
267                "vision_tower.embed.weight",
268                "vision_tower.pos_embed.weight",
269                "vision_tower.norm.weight",
270                "vision_tower.soft_token.weight",
271                "vision_tower.lm_proj.weight",
272            ] {
273                let data = weights.get(key)?;
274                stage.compiled.set_param(key, data);
275            }
276            let vcfg = self.cfg.vision.as_ref().unwrap();
277            stage
278                .compiled
279                .set_param("vision_tower.ones", &vec![1.0f32; vcfg.mm_embed_dim]);
280            stage
281                .compiled
282                .set_param("vision_tower.zero_beta", &vec![0.0f32; vcfg.mm_embed_dim]);
283        }
284        Ok(())
285    }
286
287    /// Bind audio-projector weights into the compiled graph.
288    pub fn load_audio_weights(&mut self, weights: &MultimodalWeights) -> Result<()> {
289        self.sync_layout(weights);
290        let stage = self
291            .audio
292            .as_mut()
293            .ok_or_else(|| anyhow!("audio projector not compiled — call compile_audio first"))?;
294        if stage.unified {
295            let acfg = self.cfg.audio.as_ref().unwrap();
296            let d = acfg.audio_embed_dim;
297            let data = weights.get_linear("model.embed_audio.embedding_projection.weight")?;
298            stage
299                .compiled
300                .set_param("model.embed_audio.embedding_projection.weight", &data);
301            stage
302                .compiled
303                .set_param("unified.audio.ones", &vec![1.0f32; d]);
304            stage
305                .compiled
306                .set_param("unified.audio.zero_beta", &vec![0.0f32; d]);
307        } else {
308            for key in [
309                "audio_tower.embed.weight",
310                "audio_tower.norm.weight",
311                "audio_tower.lm_proj.weight",
312            ] {
313                let data = weights.get(key)?;
314                stage.compiled.set_param(key, data);
315            }
316            let acfg = self.cfg.audio.as_ref().unwrap();
317            stage
318                .compiled
319                .set_param("audio_tower.ones", &vec![1.0f32; acfg.audio_embed_dim]);
320            stage
321                .compiled
322                .set_param("audio_tower.zero_beta", &vec![0.0f32; acfg.audio_embed_dim]);
323        }
324        Ok(())
325    }
326
327    /// Project a single image (already preprocessed to patches) into LM hidden rows.
328    pub fn project_image_patches(&mut self, patches: &[f32]) -> Result<Vec<f32>> {
329        self.project_image_with_pos(patches, None)
330    }
331
332    /// Unified vision path: `patches` + optional CPU `pos_bias`.
333    pub fn project_image_with_pos(
334        &mut self,
335        patches: &[f32],
336        pos_bias: Option<&[f32]>,
337    ) -> Result<Vec<f32>> {
338        let stage = self
339            .vision
340            .as_mut()
341            .ok_or_else(|| anyhow!("vision projector not compiled"))?;
342        let outs = if stage.unified {
343            let pos = pos_bias.ok_or_else(|| anyhow!("unified vision requires pos_bias"))?;
344            stage
345                .compiled
346                .run(&[("patches", patches), ("pos_bias", pos)])
347        } else {
348            stage.compiled.run(&[("patches", patches)])
349        };
350        outs.into_iter()
351            .next()
352            .ok_or_else(|| anyhow!("vision projector returned no outputs"))
353    }
354
355    /// Convenience: load a JPEG/PNG, resize to fit the configured
356    /// patch budget, run the projector. The vision graph is
357    /// recompiled when the produced patch count differs from the
358    /// current setting (so callers pinning to a single image size
359    /// pay zero overhead).
360    pub fn project_image_file(
361        &mut self,
362        path: impl AsRef<Path>,
363        weights: &MultimodalWeights,
364        max_side_patches: usize,
365    ) -> Result<Vec<f32>> {
366        self.sync_layout(weights);
367        if self.layout == ProjectorLayout::Unified {
368            let vcfg = self.vision_cfg()?.clone();
369            let img = load_unified_image(
370                path.as_ref(),
371                vcfg.patch_size,
372                vcfg.pooling_kernel_size,
373                self.max_soft_tokens,
374            )?;
375            let num_slots = self.max_soft_tokens;
376            let need_recompile = match &self.vision {
377                Some(s) => s.num_patches != num_slots || !s.unified,
378                None => true,
379            };
380            if need_recompile {
381                self.compile_vision(num_slots)?;
382                self.load_vision_weights(weights)?;
383            }
384            let pos_table = weights.get("model.vision_embedder.pos_embedding")?;
385            let pos_bias = factorized_pos_bias(
386                pos_table,
387                vcfg.mm_posemb_size,
388                vcfg.mm_embed_dim,
389                &img.positions,
390            );
391            let projected = self.project_image_with_pos(&img.patches, Some(&pos_bias))?;
392            return Ok(strip_valid_vision_rows(
393                &projected,
394                &img.positions,
395                self.lm_hidden,
396            ));
397        }
398
399        let vcfg = self.vision_cfg()?.clone();
400        let (patches, grid_h, grid_w) =
401            load_image_patches(path, vcfg.patch_size, max_side_patches)?;
402        let num_patches = grid_h * grid_w;
403        let need_recompile = match &self.vision {
404            Some(s) => s.num_patches != num_patches || s.unified,
405            None => true,
406        };
407        if need_recompile {
408            self.compile_vision(num_patches)?;
409            self.load_vision_weights(weights)?;
410        }
411        self.project_image_patches(&patches)
412    }
413
414    /// Project an audio waveform (already at 16 kHz mono f32) into
415    /// `[num_frames, lm_hidden]` audio soft tokens.
416    /// Project one video frame (unified layout, 70 soft tokens by default).
417    pub fn project_video_frame(
418        &mut self,
419        path: impl AsRef<Path>,
420        weights: &MultimodalWeights,
421    ) -> Result<Vec<f32>> {
422        let prev = self.max_soft_tokens;
423        self.max_soft_tokens = 70;
424        let out = self.project_image_file(path, weights, 32);
425        self.max_soft_tokens = prev;
426        out
427    }
428
429    /// Dynamic soft-token count for an on-disk image (unified layout).
430    pub fn image_soft_token_count(&self, path: impl AsRef<Path>) -> Result<usize> {
431        let vcfg = self.vision_cfg()?.clone();
432        let img =
433            image::open(path.as_ref()).map_err(|e| anyhow!("decode {:?}: {e}", path.as_ref()))?;
434        let (w, h) = (img.width() as usize, img.height() as usize);
435        compute_num_soft_tokens_from_size(
436            h,
437            w,
438            vcfg.patch_size,
439            vcfg.pooling_kernel_size,
440            self.max_soft_tokens,
441        )
442    }
443
444    pub fn project_audio_samples(
445        &mut self,
446        samples: &[f32],
447        weights: &MultimodalWeights,
448    ) -> Result<Vec<f32>> {
449        self.sync_layout(weights);
450        let acfg = self.audio_cfg()?.clone();
451        let prepared = if self.layout == ProjectorLayout::Unified {
452            prepare_unified_audio_samples(
453                samples,
454                acfg.audio_samples_per_token,
455                self.max_audio_tokens,
456            )
457        } else {
458            samples.to_vec()
459        };
460        let (frames, num_frames) = frame_audio_samples(&prepared, acfg.audio_samples_per_token)?;
461        let effective_frames = if self.layout == ProjectorLayout::Unified {
462            unified_audio_token_count(
463                samples
464                    .len()
465                    .min(crate::unified_preprocess::MAX_AUDIO_SAMPLES),
466                acfg.audio_samples_per_token,
467                self.max_audio_tokens,
468            )
469        } else {
470            num_frames
471        };
472        let need_recompile = match &self.audio {
473            Some(s) => {
474                s.num_frames != num_frames || s.unified != (self.layout == ProjectorLayout::Unified)
475            }
476            None => true,
477        };
478        if need_recompile {
479            self.compile_audio(num_frames)?;
480            self.load_audio_weights(weights)?;
481        }
482        let stage = self
483            .audio
484            .as_mut()
485            .ok_or_else(|| anyhow!("audio projector not compiled"))?;
486        let outs = stage.compiled.run(&[("frames", &frames[..])]);
487        let projected = outs
488            .into_iter()
489            .next()
490            .ok_or_else(|| anyhow!("audio projector returned no outputs"))?;
491        if self.layout == ProjectorLayout::Unified {
492            Ok(projected[..effective_frames * self.lm_hidden].to_vec())
493        } else {
494            Ok(projected)
495        }
496    }
497
498    /// Convenience: load and project a WAV file.
499    pub fn project_audio_file(
500        &mut self,
501        path: impl AsRef<Path>,
502        weights: &MultimodalWeights,
503    ) -> Result<Vec<f32>> {
504        self.sync_layout(weights);
505        let samples = crate::multimodal::load_wav_mono_16khz(path)?;
506        if self.audio.is_none() {
507            self.compile_audio(
508                samples
509                    .len()
510                    .div_ceil(self.audio_cfg()?.audio_samples_per_token),
511            )?;
512            self.load_audio_weights(weights)?;
513        }
514        let was_present = self.audio.is_some();
515        let result = self.project_audio_samples(&samples, weights);
516        if !was_present {
517            self.load_audio_weights(weights)?;
518        }
519        result
520    }
521
522    pub fn fuse_text_and_media(
523        &self,
524        text_embeds: &mut [f32],
525        token_ids: &[u32],
526        image_embeds: &[f32],
527        audio_embeds: &[f32],
528        video_embeds: &[f32],
529    ) -> Result<()> {
530        fuse_multimodal_embeddings(
531            text_embeds,
532            token_ids,
533            self.lm_hidden,
534            &self.cfg,
535            image_embeds,
536            audio_embeds,
537            video_embeds,
538        )
539    }
540
541    /// Tokenize a multimodal prompt with per-media dynamic placeholder counts.
542    pub fn tokenize_prompt<F>(
543        &self,
544        prompt: &str,
545        image_soft_counts: &[usize],
546        audio_sample_lengths: &[usize],
547        video_soft_counts: &[usize],
548        encode_fn: F,
549    ) -> Result<Vec<u32>>
550    where
551        F: FnMut(&str) -> Result<Vec<u32>>,
552    {
553        let audio_per = self
554            .cfg
555            .audio
556            .as_ref()
557            .map(|a| a.audio_samples_per_token)
558            .unwrap_or(640);
559
560        let mut slots: Vec<MediaSlot> = Vec::new();
561        let mut img_idx = 0usize;
562        let mut aud_idx = 0usize;
563        let mut vid_idx = 0usize;
564        let mut cursor = 0usize;
565        while cursor <= prompt.len() {
566            let remainder = &prompt[cursor..];
567            let next = find_next_marker(remainder);
568            match next {
569                Some((off, kind, marker_len)) => {
570                    match kind {
571                        MarkerKind::Image => {
572                            let count = *image_soft_counts.get(img_idx).ok_or_else(|| {
573                                anyhow!(
574                                    "not enough image_soft_counts for marker at offset {cursor}"
575                                )
576                            })?;
577                            slots.push(MediaSlot::Image { count });
578                            img_idx += 1;
579                        }
580                        MarkerKind::Audio => {
581                            let n = *audio_sample_lengths.get(aud_idx).ok_or_else(|| {
582                                anyhow!(
583                                    "not enough audio_sample_lengths for marker at offset {cursor}"
584                                )
585                            })?;
586                            let count = if audio_per == 640 {
587                                unified_audio_token_count(
588                                    n.min(crate::unified_preprocess::MAX_AUDIO_SAMPLES),
589                                    audio_per,
590                                    self.max_audio_tokens,
591                                )
592                            } else {
593                                n.div_ceil(audio_per).max(1)
594                            };
595                            slots.push(MediaSlot::Audio { count });
596                            aud_idx += 1;
597                        }
598                        MarkerKind::Video => {
599                            let count = *video_soft_counts.get(vid_idx).ok_or_else(|| {
600                                anyhow!(
601                                    "not enough video_soft_counts for marker at offset {cursor}"
602                                )
603                            })?;
604                            slots.push(MediaSlot::Video { count });
605                            vid_idx += 1;
606                        }
607                    }
608                    cursor += off + marker_len;
609                }
610                None => break,
611            }
612        }
613        if img_idx != image_soft_counts.len() {
614            bail!(
615                "prompt has {img_idx} image markers but {} image_soft_counts supplied",
616                image_soft_counts.len()
617            );
618        }
619        if aud_idx != audio_sample_lengths.len() {
620            bail!(
621                "prompt has {aud_idx} audio markers but {} audio lengths supplied",
622                audio_sample_lengths.len()
623            );
624        }
625        if vid_idx != video_soft_counts.len() {
626            bail!(
627                "prompt has {vid_idx} video markers but {} video_soft_counts supplied",
628                video_soft_counts.len()
629            );
630        }
631        tokenize_with_media(prompt, &slots, &self.cfg, encode_fn)
632    }
633}
634
/// Media category of a placeholder marker found in a prompt.
#[derive(Clone, Copy)]
enum MarkerKind {
    /// An image marker.
    Image,
    /// An audio marker.
    Audio,
    /// A video marker.
    Video,
}
641
642fn find_next_marker(prompt: &str) -> Option<(usize, MarkerKind, usize)> {
643    let candidates: &[(&str, MarkerKind)] = &[
644        (IMAGE_MARKER_HF, MarkerKind::Image),
645        (IMAGE_MARKER, MarkerKind::Image),
646        (AUDIO_MARKER_HF, MarkerKind::Audio),
647        (AUDIO_MARKER, MarkerKind::Audio),
648        (VIDEO_MARKER_HF, MarkerKind::Video),
649        (VIDEO_MARKER, MarkerKind::Video),
650    ];
651    let mut best: Option<(usize, MarkerKind, usize)> = None;
652    for &(m, kind) in candidates {
653        if let Some(i) = prompt.find(m) {
654            if best.map(|(bi, _, _)| i < bi).unwrap_or(true) {
655                best = Some((i, kind, m.len()));
656            }
657        }
658    }
659    best
660}
661
662// ── Safetensors weights loader ──────────────────────────────────
663
/// In-memory bundle of multimodal projector weights.
///
/// Stores every tensor as owned f32 data keyed by name; `get()` hands out
/// slices borrowed from the bundle.
pub struct MultimodalWeights {
    /// Tensor name → f32 values.
    data: HashMap<String, Vec<f32>>,
    /// File the bundle was loaded from, kept for diagnostics only.
    pub source: Option<PathBuf>,
}
672
673impl MultimodalWeights {
674    pub fn empty() -> Self {
675        Self {
676            data: HashMap::new(),
677            source: None,
678        }
679    }
680
681    /// Load vision/audio projector tensors from a safetensors file.
682    /// Supports legacy `vision_tower.*` / `audio_tower.*` and unified
683    /// `model.vision_embedder.*` / `model.embed_*` keys.
684    pub fn from_safetensors(path: impl AsRef<Path>) -> Result<Self> {
685        let path = path.as_ref();
686        let bytes = std::fs::read(path).with_context(|| format!("read {path:?}"))?;
687        let st = SafeTensors::deserialize(&bytes)
688            .map_err(|e| anyhow!("parse safetensors {path:?}: {e}"))?;
689        let mut data = HashMap::new();
690        for (name, view) in st.tensors() {
691            if !is_multimodal_tensor_key(&name) {
692                continue;
693            }
694            let shape: Vec<usize> = view.shape().to_vec();
695            let mut f32_data = tensor_to_f32(&view).with_context(|| format!("decode {name}"))?;
696            if should_transpose_hf_linear(&name, &shape) {
697                f32_data = transpose_2d(&f32_data, shape[0], shape[1]);
698            }
699            data.insert(name, f32_data);
700        }
701        Ok(Self {
702            data,
703            source: Some(path.to_path_buf()),
704        })
705    }
706
707    /// Load projector weights from a llama.cpp-style **`mmproj.gguf`**
708    /// companion file. llama.cpp ships multimodal projector weights
709    /// in a separate GGUF next to the LM `model.gguf` — this loader
710    /// drains every tensor whose name starts with `vision_tower.` /
711    /// `audio_tower.` and dequantizes it to F32.
712    ///
713    /// Tensors with other prefixes are ignored. The dequantization
714    /// path supports every K-quant scheme `rlx_gguf` decodes
715    /// (`Q8_0`, `Q4_0`, `Q4_K_M`, `Q5_K_S`, `Q6_K`, F16, BF16, F32).
716    pub fn from_mmproj_gguf(path: impl AsRef<Path>) -> Result<Self> {
717        let path = path.as_ref();
718        let file = rlx_gguf::GgufFile::from_path(path)
719            .with_context(|| format!("open mmproj GGUF {path:?}"))?;
720        let mut data = HashMap::new();
721        let keys: Vec<String> = file.keys().map(str::to_string).collect();
722        for name in keys {
723            if !is_multimodal_tensor_key(&name) {
724                continue;
725            }
726            let (decoded, shape) = file
727                .dequant_f32(&name)
728                .with_context(|| format!("dequant `{name}`"))?;
729            let mut f32_data = decoded;
730            if shape.len() == 2 && should_transpose_hf_linear(&name, &shape) {
731                f32_data = transpose_2d(&f32_data, shape[0], shape[1]);
732            }
733            data.insert(name, f32_data);
734        }
735        Ok(Self {
736            data,
737            source: Some(path.to_path_buf()),
738        })
739    }
740
741    /// Insert a custom-source tensor (mainly for tests + callers that
742    /// pull weights from a quantized GGUF).
743    pub fn insert(&mut self, key: impl Into<String>, data: Vec<f32>) {
744        self.data.insert(key.into(), data);
745    }
746
747    pub fn get(&self, key: &str) -> Result<&[f32]> {
748        self.data
749            .get(key)
750            .map(|v| v.as_slice())
751            .ok_or_else(|| anyhow!("multimodal weight missing: `{key}`"))
752    }
753
754    /// Matmul-ready linear weight (HF `out×in` already transposed when loaded).
755    pub fn get_linear(&self, key: &str) -> Result<Vec<f32>> {
756        Ok(self.get(key)?.to_vec())
757    }
758
759    pub fn layout(&self) -> ProjectorLayout {
760        if is_unified_vision_weights(self.keys()) {
761            ProjectorLayout::Unified
762        } else {
763            ProjectorLayout::LegacyPool
764        }
765    }
766
767    pub fn keys(&self) -> impl Iterator<Item = &str> {
768        self.data.keys().map(|s| s.as_str())
769    }
770}
771
/// `true` when `name` starts with one of the projector tensor prefixes
/// (legacy towers or unified embedders).
fn is_multimodal_tensor_key(name: &str) -> bool {
    const PREFIXES: [&str; 5] = [
        "vision_tower.",
        "audio_tower.",
        "model.vision_embedder.",
        "model.embed_vision.",
        "model.embed_audio.",
    ];
    PREFIXES.iter().any(|&prefix| name.starts_with(prefix))
}
779
/// `true` for 2-D tensors whose name contains `patch_dense.weight` or
/// `embedding_projection.weight` — the weights that get stored transposed.
fn should_transpose_hf_linear(name: &str, shape: &[usize]) -> bool {
    if shape.len() != 2 {
        return false;
    }
    ["patch_dense.weight", "embedding_projection.weight"]
        .iter()
        .any(|&needle| name.contains(needle))
}
784
/// Transpose a row-major `rows × cols` matrix into a row-major
/// `cols × rows` matrix.
///
/// Panics if `data` holds fewer than `rows * cols` values; extra trailing
/// values are ignored.
fn transpose_2d(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    (0..rows * cols)
        .map(|flat| {
            // `flat` indexes the output: output row = source column.
            let src_col = flat / rows;
            let src_row = flat % rows;
            data[src_row * cols + src_col]
        })
        .collect()
}
794
795fn tensor_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result<Vec<f32>> {
796    use safetensors::tensor::Dtype;
797    match view.dtype() {
798        Dtype::F32 => {
799            let raw = view.data();
800            if raw.len() % 4 != 0 {
801                bail!("F32 tensor data length not multiple of 4");
802            }
803            let mut out = Vec::with_capacity(raw.len() / 4);
804            for ch in raw.chunks_exact(4) {
805                out.push(f32::from_le_bytes([ch[0], ch[1], ch[2], ch[3]]));
806            }
807            Ok(out)
808        }
809        Dtype::F16 => {
810            let raw = view.data();
811            if raw.len() % 2 != 0 {
812                bail!("F16 tensor data length not multiple of 2");
813            }
814            let mut out = Vec::with_capacity(raw.len() / 2);
815            for ch in raw.chunks_exact(2) {
816                let bits = u16::from_le_bytes([ch[0], ch[1]]);
817                out.push(f16_to_f32(bits));
818            }
819            Ok(out)
820        }
821        Dtype::BF16 => {
822            let raw = view.data();
823            if raw.len() % 2 != 0 {
824                bail!("BF16 tensor data length not multiple of 2");
825            }
826            let mut out = Vec::with_capacity(raw.len() / 2);
827            for ch in raw.chunks_exact(2) {
828                // bf16 = top 16 bits of an f32.
829                let bits = u32::from(u16::from_le_bytes([ch[0], ch[1]])) << 16;
830                out.push(f32::from_bits(bits));
831            }
832            Ok(out)
833        }
834        other => bail!("unsupported tensor dtype for multimodal weight: {other:?}"),
835    }
836}
837
/// Widen an IEEE 754 binary16 bit pattern to the `f32` with the same value.
///
/// Signed zeros, subnormals, normal numbers, infinities and NaNs are all
/// handled; the 10-bit fraction is carried into the top of the `f32`
/// fraction field unchanged.
fn f16_to_f32(bits: u16) -> f32 {
    let sign_bit = u32::from(bits >> 15) << 31;
    let exponent = u32::from((bits >> 10) & 0x1f);
    let fraction = u32::from(bits & 0x3ff);

    let magnitude = match (exponent, fraction) {
        // Signed zero: only the sign bit survives.
        (0, 0) => 0,
        // Subnormal: shift the leading one up to the implicit-bit position
        // (bit 10), drop it, and lower the exponent by the shift amount.
        (0, frac) => {
            let shift = frac.leading_zeros() - 21;
            let normalized = (frac << shift) & 0x3ff;
            // Biased f32 exponent for 2^(-14 - shift) is 127 - 14 - shift.
            ((113 - shift) << 23) | (normalized << 13)
        }
        // Infinity or NaN: all-ones exponent, fraction preserved.
        (0x1f, frac) => (0xff << 23) | (frac << 13),
        // Normal number: rebias the exponent from 15 to 127.
        (exp, frac) => ((exp + 112) << 23) | (frac << 13),
    };
    f32::from_bits(sign_bit | magnitude)
}
863
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multimodal::{GemmaAudioConfig, GemmaMultimodalConfig, GemmaVisionConfig};

    /// Vision half of the tiny fixture. Shared by `tiny_cfg` and
    /// `synthetic_weights` so the weight shapes cannot drift away from the
    /// config the runner is built with.
    fn tiny_vision() -> GemmaVisionConfig {
        GemmaVisionConfig {
            patch_size: 2,
            model_patch_size: 4,
            mm_embed_dim: 8,
            mm_posemb_size: 16,
            num_soft_tokens: 4,
            output_proj_dims: 8,
            pooling_kernel_size: 1,
            rms_norm_eps: 1e-6,
        }
    }

    /// Audio half of the tiny fixture; see `tiny_vision`.
    fn tiny_audio() -> GemmaAudioConfig {
        GemmaAudioConfig {
            hidden_size: 4,
            audio_embed_dim: 4,
            audio_samples_per_token: 8,
            output_proj_dims: 4,
            rms_norm_eps: 1e-6,
        }
    }

    /// Smallest multimodal config used by every test in this module.
    fn tiny_cfg() -> GemmaMultimodalConfig {
        GemmaMultimodalConfig {
            vision: Some(tiny_vision()),
            audio: Some(tiny_audio()),
            image_token_id: Some(100),
            audio_token_id: Some(200),
            boi_token_id: Some(80),
            eoi_token_id: Some(81),
            boa_token_id: Some(90),
            eoa_token_index: Some(91),
            ..Default::default()
        }
    }

    /// Build constant-valued weights whose lengths match the tiny fixture.
    ///
    /// Only the vision tensors depend on `num_patches`; no audio tensor
    /// depends on the frame count, so `_num_frames` is accepted purely to
    /// keep the call sites symmetric.
    fn synthetic_weights(num_patches: usize, _num_frames: usize) -> MultimodalWeights {
        let v = tiny_vision();
        let a = tiny_audio();
        let patch_features = v.patch_size * v.patch_size * 3;
        let mut w = MultimodalWeights::empty();
        // Constant-valued weights are fine for "graph runs end-to-end" —
        // we just need the matrices to have the right shapes.
        w.insert(
            "vision_tower.embed.weight",
            vec![0.01f32; patch_features * v.mm_embed_dim],
        );
        w.insert(
            "vision_tower.pos_embed.weight",
            vec![0.0f32; num_patches * v.mm_embed_dim],
        );
        w.insert("vision_tower.norm.weight", vec![0.0f32; v.mm_embed_dim]);
        w.insert(
            "vision_tower.soft_token.weight",
            // Patch-axis reducer: [num_patches, num_soft_tokens].
            vec![0.01f32; num_patches * v.num_soft_tokens],
        );
        w.insert(
            "vision_tower.lm_proj.weight",
            // Feature projection: [mm_embed_dim, output_proj_dims].
            vec![0.01f32; v.mm_embed_dim * v.output_proj_dims],
        );
        w.insert(
            "audio_tower.embed.weight",
            vec![0.01f32; a.audio_samples_per_token * a.audio_embed_dim],
        );
        w.insert("audio_tower.norm.weight", vec![0.0f32; a.audio_embed_dim]);
        // lm_hidden in the tiny test is 8 (= mm_embed_dim).
        w.insert(
            "audio_tower.lm_proj.weight",
            vec![0.01f32; a.audio_embed_dim * 8],
        );
        w
    }

    /// Write a minimal single-tensor safetensors file, load it through
    /// `MultimodalWeights::from_safetensors`, and delete the file again.
    ///
    /// The file is removed *before* the load result is unwrapped so a
    /// loader failure does not leave it behind, and the process id is part
    /// of the file name so concurrent test processes sharing a temp
    /// directory do not overwrite each other's fixture. `tag` must be
    /// unique per test within this module.
    fn round_trip_single_tensor(
        tag: &str,
        name: &str,
        dtype: &str,
        len: usize,
        payload: &[u8],
    ) -> MultimodalWeights {
        // Layout: u64 LE header length, JSON header, raw tensor bytes.
        let header = format!(
            r#"{{"{name}":{{"dtype":"{dtype}","shape":[{len}],"data_offsets":[0,{}]}}}}"#,
            payload.len(),
        );
        let mut file_bytes = Vec::with_capacity(8 + header.len() + payload.len());
        file_bytes.extend_from_slice(&(header.len() as u64).to_le_bytes());
        file_bytes.extend_from_slice(header.as_bytes());
        file_bytes.extend_from_slice(payload);

        let path = std::env::temp_dir().join(format!(
            "rlx_gemma_mm_{tag}_test_{}.safetensors",
            std::process::id()
        ));
        std::fs::write(&path, &file_bytes).unwrap();
        let loaded = MultimodalWeights::from_safetensors(&path);
        std::fs::remove_file(&path).ok();
        loaded.unwrap()
    }

    #[test]
    fn runner_compiles_and_runs_vision_projector_on_cpu() {
        let cfg = tiny_cfg();
        let num_patches = 4;
        let mut runner = GemmaMultimodalRunner::new(cfg, 8, Device::Cpu, None, None).unwrap();
        let weights = synthetic_weights(num_patches, 0);
        runner.compile_vision(num_patches).unwrap();
        runner.load_vision_weights(&weights).unwrap();
        // 4 patches × 2*2*3 = 12 features per patch.
        let patches = vec![0.5f32; num_patches * 12];
        let out = runner.project_image_patches(&patches).unwrap();
        // Output shape: [num_soft_tokens=4, output_proj_dims=8] = 32.
        assert_eq!(out.len(), 4 * 8);
    }

    #[test]
    fn runner_compiles_and_runs_audio_projector_on_cpu() {
        let cfg = tiny_cfg();
        let num_frames = 2;
        let mut runner = GemmaMultimodalRunner::new(cfg, 8, Device::Cpu, None, None).unwrap();
        let weights = synthetic_weights(0, num_frames);
        runner.compile_audio(num_frames).unwrap();
        runner.load_audio_weights(&weights).unwrap();
        // 2 frames × 8 samples per frame.
        let frames = vec![0.1f32; num_frames * 8];
        let out = runner.project_audio_samples(&frames, &weights).unwrap();
        // [num_frames=2, lm_hidden=8] = 16 floats.
        assert_eq!(out.len(), 2 * 8);
    }

    /// Encode an f32 as its bf16 byte representation (top 16 bits).
    fn f32_to_bf16_bytes(f: f32) -> [u8; 2] {
        let bits = f.to_bits();
        let bf16 = ((bits >> 16) & 0xffff) as u16;
        bf16.to_le_bytes()
    }

    #[test]
    fn mmproj_gguf_loader_rejects_missing_file() {
        // Sanity: missing file surfaces as an open error rather than
        // panicking. Real round-trip is exercised by callers that
        // pass an actual llama.cpp-produced `mmproj.gguf`.
        match MultimodalWeights::from_mmproj_gguf("/nonexistent/mmproj.gguf") {
            Ok(_) => panic!("expected error for missing mmproj path"),
            Err(err) => {
                let msg = format!("{err:#}");
                assert!(
                    msg.contains("mmproj") || msg.contains("open") || msg.contains("No such file"),
                    "unexpected error message: {msg}"
                );
            }
        }
    }

    #[test]
    fn weights_loader_decodes_bf16_safetensors() {
        // bf16 keeps top 16 bits of an f32, so round-trip values are
        // exact at the bf16-precision rung. Pick values whose lower
        // 16 mantissa bits are zero to avoid quantization here.
        let originals: Vec<f32> = vec![0.0, 1.0, -1.0, 0.5];
        let mut payload: Vec<u8> = Vec::with_capacity(originals.len() * 2);
        for &f in &originals {
            payload.extend_from_slice(&f32_to_bf16_bytes(f));
        }
        let w = round_trip_single_tensor(
            "bf16",
            "audio_tower.norm.weight",
            "BF16",
            originals.len(),
            &payload,
        );
        let decoded = w.get("audio_tower.norm.weight").unwrap();
        assert_eq!(decoded.len(), originals.len());
        for (got, want) in decoded.iter().zip(originals.iter()) {
            assert!(
                (got - want).abs() < 1e-6,
                "bf16 decode: got {got} want {want}"
            );
        }
    }

    #[test]
    fn weights_loader_decodes_f32_safetensors() {
        let originals = [1.0f32, 2.0, 3.0, 4.0];
        let payload: Vec<u8> = originals.iter().flat_map(|f| f.to_le_bytes()).collect();
        let w = round_trip_single_tensor(
            "f32",
            "vision_tower.norm.weight",
            "F32",
            originals.len(),
            &payload,
        );
        assert_eq!(
            w.get("vision_tower.norm.weight").unwrap(),
            &[1.0, 2.0, 3.0, 4.0]
        );
    }

    #[test]
    fn tokenize_prompt_derives_slot_counts() {
        let cfg = tiny_cfg();
        let runner = GemmaMultimodalRunner::new(cfg, 8, Device::Cpu, None, None).unwrap();
        // num_soft_tokens=4 (per vision config), audio_per=8.
        // Stand-in tokenizer: one token id per input byte.
        let encode = |s: &str| -> Result<Vec<u32>> { Ok(s.bytes().map(|b| b as u32).collect()) };
        // 16 audio samples → ceil(16/8)=2 audio tokens.
        let out = runner
            .tokenize_prompt("hi <image> see <audio>", &[4], &[16], &[], encode)
            .unwrap();
        // "hi " + boi + image*4 + eoi + " see " + boa + audio*2 + eoa
        let mut expected: Vec<u32> = b"hi ".iter().map(|b| *b as u32).collect();
        expected.extend([80, 100, 100, 100, 100, 81]);
        expected.extend(b" see ".iter().map(|b| *b as u32));
        expected.extend([90, 200, 200, 91]);
        assert_eq!(out, expected);
    }

    #[test]
    fn fuse_text_and_media_replaces_in_order() {
        let cfg = tiny_cfg();
        let runner = GemmaMultimodalRunner::new(cfg, 4, Device::Cpu, None, None).unwrap();
        let mut text = vec![0.0f32; 4 * 4]; // 4 tokens, hidden=4
        // Token 1 is the image placeholder (100), token 2 the audio one (200).
        let ids = [10, 100, 200, 11];
        let img = vec![7.0f32; 4];
        let aud = vec![9.0f32; 4];
        runner
            .fuse_text_and_media(&mut text, &ids, &img, &aud, &[])
            .unwrap();
        assert_eq!(&text[4..8], &[7.0; 4]);
        assert_eq!(&text[8..12], &[9.0; 4]);
    }
}