1use anyhow::{Context, Result, anyhow, bail};
17use rlx_runtime::{CompiledGraph, Device, Session};
18use safetensors::SafeTensors;
19use std::collections::HashMap;
20use std::path::{Path, PathBuf};
21
22use crate::multimodal::{
23 AUDIO_MARKER, AUDIO_MARKER_HF, GemmaAudioConfig, GemmaMultimodalConfig, GemmaVisionConfig,
24 IMAGE_MARKER, IMAGE_MARKER_HF, MediaSlot, VIDEO_MARKER, VIDEO_MARKER_HF,
25 build_audio_projection_graph, build_vision_projection_graph, frame_audio_samples,
26 fuse_multimodal_embeddings, load_image_patches, tokenize_with_media,
27};
28use crate::unified_preprocess::{
29 compute_num_soft_tokens_from_size, factorized_pos_bias, load_unified_image,
30 prepare_unified_audio_samples, strip_valid_vision_rows, unified_audio_token_count,
31};
32use crate::unified_projector::{
33 build_unified_audio_graph, build_unified_vision_graph, is_unified_vision_weights,
34};
35
/// Selects which projector graphs and weight-key naming scheme the runner uses.
///
/// The runner copies this value from [`MultimodalWeights::layout`] and consults
/// it when choosing between the legacy and the unified graph builders.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ProjectorLayout {
    /// Default layout: graphs come from the `build_*_projection_graph` builders
    /// and parameters use the `vision_tower.*` / `audio_tower.*` keys.
    #[default]
    LegacyPool,
    /// Graphs come from the `build_unified_*_graph` builders and parameters use
    /// the `model.vision_embedder.*` / `model.embed_*` keys.
    Unified,
}
45
/// A compiled vision projector together with the sizes it was compiled for.
struct VisionStage {
    /// Compiled projector graph; `load_vision_weights` sets its parameters.
    compiled: CompiledGraph,
    /// Patch count that was passed to the graph builder at compile time.
    num_patches: usize,
    /// Second component of `output_shape`: `lm_hidden` for unified stages,
    /// otherwise the vision config's `output_proj_dims`.
    output_dim: usize,
    /// Whether the stage was built with the unified graph builder.
    unified: bool,
}
54
55impl VisionStage {
56 fn output_shape(&self) -> (usize, usize) {
57 (self.num_patches, self.output_dim)
58 }
59}
60
/// A compiled audio projector together with the sizes it was compiled for.
struct AudioStage {
    /// Compiled projector graph; `load_audio_weights` sets its parameters.
    compiled: CompiledGraph,
    /// Frame count that was passed to the graph builder at compile time.
    num_frames: usize,
    /// Second component of `output_shape`; `compile_audio` always stores `lm_hidden` here.
    output_dim: usize,
    /// Whether the stage was built with the unified graph builder.
    unified: bool,
}
67
68impl AudioStage {
69 fn output_shape(&self) -> (usize, usize) {
70 (self.num_frames, self.output_dim)
71 }
72}
73
/// Drives the Gemma vision and audio projector stages: compiles them on demand,
/// loads their weights, runs them, and helps splice the results into a prompt.
pub struct GemmaMultimodalRunner {
    /// Multimodal configuration (optional vision / audio sub-configs and token ids).
    cfg: GemmaMultimodalConfig,
    /// Hidden size of the language model; output width of the audio stage and
    /// of the unified vision stage.
    lm_hidden: usize,
    /// Device handed to `Session::new` whenever a stage is compiled.
    device: Device,
    /// Layout last synced from a `MultimodalWeights`; starts as `LegacyPool`.
    layout: ProjectorLayout,
    /// Soft-token budget used for unified image loading; also the patch count
    /// the unified vision stage is compiled for.
    max_soft_tokens: usize,
    /// Token cap passed to the unified audio helpers.
    max_audio_tokens: usize,
    /// Compiled vision stage; `None` until `compile_vision` succeeds.
    vision: Option<VisionStage>,
    /// Compiled audio stage; `None` until `compile_audio` succeeds.
    audio: Option<AudioStage>,
}
90
91impl GemmaMultimodalRunner {
92 pub fn new(
97 cfg: GemmaMultimodalConfig,
98 lm_hidden: usize,
99 device: Device,
100 max_soft_tokens: Option<usize>,
101 max_audio_tokens: Option<usize>,
102 ) -> Result<Self> {
103 let max_soft = max_soft_tokens.unwrap_or_else(|| {
104 cfg.vision
105 .as_ref()
106 .map(|v| v.num_soft_tokens)
107 .unwrap_or(280)
108 });
109 Ok(Self {
110 cfg,
111 lm_hidden,
112 device,
113 layout: ProjectorLayout::LegacyPool,
114 max_soft_tokens: max_soft,
115 max_audio_tokens: max_audio_tokens.unwrap_or(750),
116 vision: None,
117 audio: None,
118 })
119 }
120
121 pub fn vision_output_shape(&self) -> Option<(usize, usize)> {
125 self.vision.as_ref().map(|s| s.output_shape())
126 }
127
128 pub fn audio_output_shape(&self) -> Option<(usize, usize)> {
131 self.audio.as_ref().map(|s| s.output_shape())
132 }
133
    /// Borrows the multimodal configuration the runner was constructed with.
    pub fn config(&self) -> &GemmaMultimodalConfig {
        &self.cfg
    }
137
    /// Returns the language-model hidden size passed to [`Self::new`].
    pub fn lm_hidden(&self) -> usize {
        self.lm_hidden
    }
141
    /// Returns the projector layout currently recorded on the runner
    /// (`LegacyPool` until weights have been synced).
    pub fn layout(&self) -> ProjectorLayout {
        self.layout
    }
145
146 pub fn is_unified(&self) -> bool {
147 self.layout == ProjectorLayout::Unified
148 }
149
    /// Overwrites the runner's layout with the one reported by `weights`.
    /// Already-compiled stages are not touched.
    fn sync_layout(&mut self, weights: &MultimodalWeights) {
        self.layout = weights.layout();
    }
153
154 fn vision_cfg(&self) -> Result<&GemmaVisionConfig> {
155 self.cfg
156 .vision
157 .as_ref()
158 .ok_or_else(|| anyhow!("vision config missing — model is text/audio only"))
159 }
160
161 fn audio_cfg(&self) -> Result<&GemmaAudioConfig> {
162 self.cfg
163 .audio
164 .as_ref()
165 .ok_or_else(|| anyhow!("audio config missing — model is text/vision only"))
166 }
167
168 pub fn compile_vision(&mut self, num_patches: usize) -> Result<()> {
172 let vcfg = self.vision_cfg()?.clone();
173 let unified = self.layout == ProjectorLayout::Unified;
174 let g = if unified {
175 build_unified_vision_graph(num_patches, &vcfg)?
176 } else {
177 build_vision_projection_graph(1, num_patches, &vcfg)?
178 };
179 let session = Session::new(self.device);
180 let compiled = session
181 .compile_hir(g.hir)
182 .map_err(|e| anyhow!("vision projector lower failed: {e:?}"))?;
183 self.vision = Some(VisionStage {
184 compiled,
185 num_patches,
186 output_dim: if unified {
187 self.lm_hidden
188 } else {
189 vcfg.output_proj_dims
190 },
191 unified,
192 });
193 Ok(())
194 }
195
196 pub fn compile_audio(&mut self, num_frames: usize) -> Result<()> {
197 let acfg = self.audio_cfg()?.clone();
198 let unified = self.layout == ProjectorLayout::Unified;
199 let g = if unified {
200 build_unified_audio_graph(num_frames, &acfg, self.lm_hidden)?
201 } else {
202 build_audio_projection_graph(1, num_frames, &acfg, self.lm_hidden)?
203 };
204 let session = Session::new(self.device);
205 let compiled = session
206 .compile_hir(g.hir)
207 .map_err(|e| anyhow!("audio projector lower failed: {e:?}"))?;
208 self.audio = Some(AudioStage {
209 compiled,
210 num_frames,
211 output_dim: self.lm_hidden,
212 unified,
213 });
214 Ok(())
215 }
216
217 pub fn reconfigure(
221 &mut self,
222 num_patches: Option<usize>,
223 num_frames: Option<usize>,
224 weights: &MultimodalWeights,
225 ) -> Result<()> {
226 if let Some(n) = num_patches {
227 self.compile_vision(n)?;
228 self.load_vision_weights(weights)?;
229 }
230 if let Some(n) = num_frames {
231 self.compile_audio(n)?;
232 self.load_audio_weights(weights)?;
233 }
234 Ok(())
235 }
236
237 pub fn load_vision_weights(&mut self, weights: &MultimodalWeights) -> Result<()> {
239 self.sync_layout(weights);
240 let stage = self
241 .vision
242 .as_mut()
243 .ok_or_else(|| anyhow!("vision projector not compiled — call compile_vision first"))?;
244 if stage.unified {
245 let vcfg = self.cfg.vision.as_ref().unwrap();
246 let d = vcfg.mm_embed_dim;
247 for key in [
248 "model.vision_embedder.patch_ln1.weight",
249 "model.vision_embedder.patch_ln1.bias",
250 "model.vision_embedder.patch_dense.weight",
251 "model.vision_embedder.patch_dense.bias",
252 "model.vision_embedder.patch_ln2.weight",
253 "model.vision_embedder.patch_ln2.bias",
254 "model.vision_embedder.pos_norm.weight",
255 "model.vision_embedder.pos_norm.bias",
256 "model.embed_vision.embedding_projection.weight",
257 ] {
258 let data = weights.get_linear(key)?;
259 stage.compiled.set_param(key, &data);
260 }
261 stage.compiled.set_param("unified.ones", &vec![1.0f32; d]);
262 stage
263 .compiled
264 .set_param("unified.zero_beta", &vec![0.0f32; d]);
265 } else {
266 for key in [
267 "vision_tower.embed.weight",
268 "vision_tower.pos_embed.weight",
269 "vision_tower.norm.weight",
270 "vision_tower.soft_token.weight",
271 "vision_tower.lm_proj.weight",
272 ] {
273 let data = weights.get(key)?;
274 stage.compiled.set_param(key, data);
275 }
276 let vcfg = self.cfg.vision.as_ref().unwrap();
277 stage
278 .compiled
279 .set_param("vision_tower.ones", &vec![1.0f32; vcfg.mm_embed_dim]);
280 stage
281 .compiled
282 .set_param("vision_tower.zero_beta", &vec![0.0f32; vcfg.mm_embed_dim]);
283 }
284 Ok(())
285 }
286
287 pub fn load_audio_weights(&mut self, weights: &MultimodalWeights) -> Result<()> {
289 self.sync_layout(weights);
290 let stage = self
291 .audio
292 .as_mut()
293 .ok_or_else(|| anyhow!("audio projector not compiled — call compile_audio first"))?;
294 if stage.unified {
295 let acfg = self.cfg.audio.as_ref().unwrap();
296 let d = acfg.audio_embed_dim;
297 let data = weights.get_linear("model.embed_audio.embedding_projection.weight")?;
298 stage
299 .compiled
300 .set_param("model.embed_audio.embedding_projection.weight", &data);
301 stage
302 .compiled
303 .set_param("unified.audio.ones", &vec![1.0f32; d]);
304 stage
305 .compiled
306 .set_param("unified.audio.zero_beta", &vec![0.0f32; d]);
307 } else {
308 for key in [
309 "audio_tower.embed.weight",
310 "audio_tower.norm.weight",
311 "audio_tower.lm_proj.weight",
312 ] {
313 let data = weights.get(key)?;
314 stage.compiled.set_param(key, data);
315 }
316 let acfg = self.cfg.audio.as_ref().unwrap();
317 stage
318 .compiled
319 .set_param("audio_tower.ones", &vec![1.0f32; acfg.audio_embed_dim]);
320 stage
321 .compiled
322 .set_param("audio_tower.zero_beta", &vec![0.0f32; acfg.audio_embed_dim]);
323 }
324 Ok(())
325 }
326
    /// Runs the compiled vision stage on `patches` without a positional bias.
    ///
    /// Shorthand for `project_image_with_pos(patches, None)`; a stage compiled
    /// for the unified layout rejects this because it requires a bias.
    pub fn project_image_patches(&mut self, patches: &[f32]) -> Result<Vec<f32>> {
        self.project_image_with_pos(patches, None)
    }
331
332 pub fn project_image_with_pos(
334 &mut self,
335 patches: &[f32],
336 pos_bias: Option<&[f32]>,
337 ) -> Result<Vec<f32>> {
338 let stage = self
339 .vision
340 .as_mut()
341 .ok_or_else(|| anyhow!("vision projector not compiled"))?;
342 let outs = if stage.unified {
343 let pos = pos_bias.ok_or_else(|| anyhow!("unified vision requires pos_bias"))?;
344 stage
345 .compiled
346 .run(&[("patches", patches), ("pos_bias", pos)])
347 } else {
348 stage.compiled.run(&[("patches", patches)])
349 };
350 outs.into_iter()
351 .next()
352 .ok_or_else(|| anyhow!("vision projector returned no outputs"))
353 }
354
355 pub fn project_image_file(
361 &mut self,
362 path: impl AsRef<Path>,
363 weights: &MultimodalWeights,
364 max_side_patches: usize,
365 ) -> Result<Vec<f32>> {
366 self.sync_layout(weights);
367 if self.layout == ProjectorLayout::Unified {
368 let vcfg = self.vision_cfg()?.clone();
369 let img = load_unified_image(
370 path.as_ref(),
371 vcfg.patch_size,
372 vcfg.pooling_kernel_size,
373 self.max_soft_tokens,
374 )?;
375 let num_slots = self.max_soft_tokens;
376 let need_recompile = match &self.vision {
377 Some(s) => s.num_patches != num_slots || !s.unified,
378 None => true,
379 };
380 if need_recompile {
381 self.compile_vision(num_slots)?;
382 self.load_vision_weights(weights)?;
383 }
384 let pos_table = weights.get("model.vision_embedder.pos_embedding")?;
385 let pos_bias = factorized_pos_bias(
386 pos_table,
387 vcfg.mm_posemb_size,
388 vcfg.mm_embed_dim,
389 &img.positions,
390 );
391 let projected = self.project_image_with_pos(&img.patches, Some(&pos_bias))?;
392 return Ok(strip_valid_vision_rows(
393 &projected,
394 &img.positions,
395 self.lm_hidden,
396 ));
397 }
398
399 let vcfg = self.vision_cfg()?.clone();
400 let (patches, grid_h, grid_w) =
401 load_image_patches(path, vcfg.patch_size, max_side_patches)?;
402 let num_patches = grid_h * grid_w;
403 let need_recompile = match &self.vision {
404 Some(s) => s.num_patches != num_patches || s.unified,
405 None => true,
406 };
407 if need_recompile {
408 self.compile_vision(num_patches)?;
409 self.load_vision_weights(weights)?;
410 }
411 self.project_image_patches(&patches)
412 }
413
414 pub fn project_video_frame(
418 &mut self,
419 path: impl AsRef<Path>,
420 weights: &MultimodalWeights,
421 ) -> Result<Vec<f32>> {
422 let prev = self.max_soft_tokens;
423 self.max_soft_tokens = 70;
424 let out = self.project_image_file(path, weights, 32);
425 self.max_soft_tokens = prev;
426 out
427 }
428
429 pub fn image_soft_token_count(&self, path: impl AsRef<Path>) -> Result<usize> {
431 let vcfg = self.vision_cfg()?.clone();
432 let img =
433 image::open(path.as_ref()).map_err(|e| anyhow!("decode {:?}: {e}", path.as_ref()))?;
434 let (w, h) = (img.width() as usize, img.height() as usize);
435 compute_num_soft_tokens_from_size(
436 h,
437 w,
438 vcfg.patch_size,
439 vcfg.pooling_kernel_size,
440 self.max_soft_tokens,
441 )
442 }
443
444 pub fn project_audio_samples(
445 &mut self,
446 samples: &[f32],
447 weights: &MultimodalWeights,
448 ) -> Result<Vec<f32>> {
449 self.sync_layout(weights);
450 let acfg = self.audio_cfg()?.clone();
451 let prepared = if self.layout == ProjectorLayout::Unified {
452 prepare_unified_audio_samples(
453 samples,
454 acfg.audio_samples_per_token,
455 self.max_audio_tokens,
456 )
457 } else {
458 samples.to_vec()
459 };
460 let (frames, num_frames) = frame_audio_samples(&prepared, acfg.audio_samples_per_token)?;
461 let effective_frames = if self.layout == ProjectorLayout::Unified {
462 unified_audio_token_count(
463 samples
464 .len()
465 .min(crate::unified_preprocess::MAX_AUDIO_SAMPLES),
466 acfg.audio_samples_per_token,
467 self.max_audio_tokens,
468 )
469 } else {
470 num_frames
471 };
472 let need_recompile = match &self.audio {
473 Some(s) => {
474 s.num_frames != num_frames || s.unified != (self.layout == ProjectorLayout::Unified)
475 }
476 None => true,
477 };
478 if need_recompile {
479 self.compile_audio(num_frames)?;
480 self.load_audio_weights(weights)?;
481 }
482 let stage = self
483 .audio
484 .as_mut()
485 .ok_or_else(|| anyhow!("audio projector not compiled"))?;
486 let outs = stage.compiled.run(&[("frames", &frames[..])]);
487 let projected = outs
488 .into_iter()
489 .next()
490 .ok_or_else(|| anyhow!("audio projector returned no outputs"))?;
491 if self.layout == ProjectorLayout::Unified {
492 Ok(projected[..effective_frames * self.lm_hidden].to_vec())
493 } else {
494 Ok(projected)
495 }
496 }
497
498 pub fn project_audio_file(
500 &mut self,
501 path: impl AsRef<Path>,
502 weights: &MultimodalWeights,
503 ) -> Result<Vec<f32>> {
504 self.sync_layout(weights);
505 let samples = crate::multimodal::load_wav_mono_16khz(path)?;
506 if self.audio.is_none() {
507 self.compile_audio(
508 samples
509 .len()
510 .div_ceil(self.audio_cfg()?.audio_samples_per_token),
511 )?;
512 self.load_audio_weights(weights)?;
513 }
514 let was_present = self.audio.is_some();
515 let result = self.project_audio_samples(&samples, weights);
516 if !was_present {
517 self.load_audio_weights(weights)?;
518 }
519 result
520 }
521
    /// Splices projected media embeddings into `text_embeds` in place.
    ///
    /// Thin wrapper around `fuse_multimodal_embeddings` that supplies the
    /// runner's `lm_hidden` and configuration. The unit test in this file
    /// shows rows at image / audio token ids being overwritten with the
    /// supplied embeddings.
    pub fn fuse_text_and_media(
        &self,
        text_embeds: &mut [f32],
        token_ids: &[u32],
        image_embeds: &[f32],
        audio_embeds: &[f32],
        video_embeds: &[f32],
    ) -> Result<()> {
        fuse_multimodal_embeddings(
            text_embeds,
            token_ids,
            self.lm_hidden,
            &self.cfg,
            image_embeds,
            audio_embeds,
            video_embeds,
        )
    }
540
541 pub fn tokenize_prompt<F>(
543 &self,
544 prompt: &str,
545 image_soft_counts: &[usize],
546 audio_sample_lengths: &[usize],
547 video_soft_counts: &[usize],
548 encode_fn: F,
549 ) -> Result<Vec<u32>>
550 where
551 F: FnMut(&str) -> Result<Vec<u32>>,
552 {
553 let audio_per = self
554 .cfg
555 .audio
556 .as_ref()
557 .map(|a| a.audio_samples_per_token)
558 .unwrap_or(640);
559
560 let mut slots: Vec<MediaSlot> = Vec::new();
561 let mut img_idx = 0usize;
562 let mut aud_idx = 0usize;
563 let mut vid_idx = 0usize;
564 let mut cursor = 0usize;
565 while cursor <= prompt.len() {
566 let remainder = &prompt[cursor..];
567 let next = find_next_marker(remainder);
568 match next {
569 Some((off, kind, marker_len)) => {
570 match kind {
571 MarkerKind::Image => {
572 let count = *image_soft_counts.get(img_idx).ok_or_else(|| {
573 anyhow!(
574 "not enough image_soft_counts for marker at offset {cursor}"
575 )
576 })?;
577 slots.push(MediaSlot::Image { count });
578 img_idx += 1;
579 }
580 MarkerKind::Audio => {
581 let n = *audio_sample_lengths.get(aud_idx).ok_or_else(|| {
582 anyhow!(
583 "not enough audio_sample_lengths for marker at offset {cursor}"
584 )
585 })?;
586 let count = if audio_per == 640 {
587 unified_audio_token_count(
588 n.min(crate::unified_preprocess::MAX_AUDIO_SAMPLES),
589 audio_per,
590 self.max_audio_tokens,
591 )
592 } else {
593 n.div_ceil(audio_per).max(1)
594 };
595 slots.push(MediaSlot::Audio { count });
596 aud_idx += 1;
597 }
598 MarkerKind::Video => {
599 let count = *video_soft_counts.get(vid_idx).ok_or_else(|| {
600 anyhow!(
601 "not enough video_soft_counts for marker at offset {cursor}"
602 )
603 })?;
604 slots.push(MediaSlot::Video { count });
605 vid_idx += 1;
606 }
607 }
608 cursor += off + marker_len;
609 }
610 None => break,
611 }
612 }
613 if img_idx != image_soft_counts.len() {
614 bail!(
615 "prompt has {img_idx} image markers but {} image_soft_counts supplied",
616 image_soft_counts.len()
617 );
618 }
619 if aud_idx != audio_sample_lengths.len() {
620 bail!(
621 "prompt has {aud_idx} audio markers but {} audio lengths supplied",
622 audio_sample_lengths.len()
623 );
624 }
625 if vid_idx != video_soft_counts.len() {
626 bail!(
627 "prompt has {vid_idx} video markers but {} video_soft_counts supplied",
628 video_soft_counts.len()
629 );
630 }
631 tokenize_with_media(prompt, &slots, &self.cfg, encode_fn)
632 }
633}
634
/// Kind of media placeholder recognised in a prompt by `find_next_marker`.
#[derive(Clone, Copy)]
enum MarkerKind {
    /// Matched `IMAGE_MARKER` or `IMAGE_MARKER_HF`.
    Image,
    /// Matched `AUDIO_MARKER` or `AUDIO_MARKER_HF`.
    Audio,
    /// Matched `VIDEO_MARKER` or `VIDEO_MARKER_HF`.
    Video,
}
641
642fn find_next_marker(prompt: &str) -> Option<(usize, MarkerKind, usize)> {
643 let candidates: &[(&str, MarkerKind)] = &[
644 (IMAGE_MARKER_HF, MarkerKind::Image),
645 (IMAGE_MARKER, MarkerKind::Image),
646 (AUDIO_MARKER_HF, MarkerKind::Audio),
647 (AUDIO_MARKER, MarkerKind::Audio),
648 (VIDEO_MARKER_HF, MarkerKind::Video),
649 (VIDEO_MARKER, MarkerKind::Video),
650 ];
651 let mut best: Option<(usize, MarkerKind, usize)> = None;
652 for &(m, kind) in candidates {
653 if let Some(i) = prompt.find(m) {
654 if best.map(|(bi, _, _)| i < bi).unwrap_or(true) {
655 best = Some((i, kind, m.len()));
656 }
657 }
658 }
659 best
660}
661
/// In-memory store of multimodal projector tensors, decoded to `f32` and
/// keyed by tensor name.
pub struct MultimodalWeights {
    /// Flat `f32` values per tensor name.
    data: HashMap<String, Vec<f32>>,
    /// File the weights were loaded from; `None` for stores built via `empty`.
    pub source: Option<PathBuf>,
}
672
673impl MultimodalWeights {
674 pub fn empty() -> Self {
675 Self {
676 data: HashMap::new(),
677 source: None,
678 }
679 }
680
681 pub fn from_safetensors(path: impl AsRef<Path>) -> Result<Self> {
685 let path = path.as_ref();
686 let bytes = std::fs::read(path).with_context(|| format!("read {path:?}"))?;
687 let st = SafeTensors::deserialize(&bytes)
688 .map_err(|e| anyhow!("parse safetensors {path:?}: {e}"))?;
689 let mut data = HashMap::new();
690 for (name, view) in st.tensors() {
691 if !is_multimodal_tensor_key(&name) {
692 continue;
693 }
694 let shape: Vec<usize> = view.shape().to_vec();
695 let mut f32_data = tensor_to_f32(&view).with_context(|| format!("decode {name}"))?;
696 if should_transpose_hf_linear(&name, &shape) {
697 f32_data = transpose_2d(&f32_data, shape[0], shape[1]);
698 }
699 data.insert(name, f32_data);
700 }
701 Ok(Self {
702 data,
703 source: Some(path.to_path_buf()),
704 })
705 }
706
707 pub fn from_mmproj_gguf(path: impl AsRef<Path>) -> Result<Self> {
717 let path = path.as_ref();
718 let file = rlx_gguf::GgufFile::from_path(path)
719 .with_context(|| format!("open mmproj GGUF {path:?}"))?;
720 let mut data = HashMap::new();
721 let keys: Vec<String> = file.keys().map(str::to_string).collect();
722 for name in keys {
723 if !is_multimodal_tensor_key(&name) {
724 continue;
725 }
726 let (decoded, shape) = file
727 .dequant_f32(&name)
728 .with_context(|| format!("dequant `{name}`"))?;
729 let mut f32_data = decoded;
730 if shape.len() == 2 && should_transpose_hf_linear(&name, &shape) {
731 f32_data = transpose_2d(&f32_data, shape[0], shape[1]);
732 }
733 data.insert(name, f32_data);
734 }
735 Ok(Self {
736 data,
737 source: Some(path.to_path_buf()),
738 })
739 }
740
    /// Stores `data` under `key`, replacing any tensor already stored there.
    /// The values are stored as given; no transposition is applied.
    pub fn insert(&mut self, key: impl Into<String>, data: Vec<f32>) {
        self.data.insert(key.into(), data);
    }
746
747 pub fn get(&self, key: &str) -> Result<&[f32]> {
748 self.data
749 .get(key)
750 .map(|v| v.as_slice())
751 .ok_or_else(|| anyhow!("multimodal weight missing: `{key}`"))
752 }
753
754 pub fn get_linear(&self, key: &str) -> Result<Vec<f32>> {
756 Ok(self.get(key)?.to_vec())
757 }
758
759 pub fn layout(&self) -> ProjectorLayout {
760 if is_unified_vision_weights(self.keys()) {
761 ProjectorLayout::Unified
762 } else {
763 ProjectorLayout::LegacyPool
764 }
765 }
766
767 pub fn keys(&self) -> impl Iterator<Item = &str> {
768 self.data.keys().map(|s| s.as_str())
769 }
770}
771
/// Tells whether `name` belongs to one of the multimodal tensor namespaces
/// that the loaders keep.
fn is_multimodal_tensor_key(name: &str) -> bool {
    const PREFIXES: [&str; 5] = [
        "vision_tower.",
        "audio_tower.",
        "model.vision_embedder.",
        "model.embed_vision.",
        "model.embed_audio.",
    ];
    PREFIXES.iter().any(|&prefix| name.starts_with(prefix))
}
779
/// Decides whether a tensor must be transposed on load: it has to be
/// two-dimensional and its name has to contain `patch_dense.weight` or
/// `embedding_projection.weight`.
fn should_transpose_hf_linear(name: &str, shape: &[usize]) -> bool {
    if shape.len() != 2 {
        return false;
    }
    ["patch_dense.weight", "embedding_projection.weight"]
        .iter()
        .any(|&suffix| name.contains(suffix))
}
784
/// Transposes a row-major `rows x cols` matrix into a row-major
/// `cols x rows` matrix.
///
/// Only the first `rows * cols` values of `data` are read.
///
/// # Panics
/// Panics when `data` holds fewer than `rows * cols` values.
fn transpose_2d(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    (0..rows * cols)
        .map(|dst| {
            // `dst` walks the output; recover the source coordinates from it.
            let (c, r) = (dst / rows, dst % rows);
            data[r * cols + c]
        })
        .collect()
}
794
795fn tensor_to_f32(view: &safetensors::tensor::TensorView<'_>) -> Result<Vec<f32>> {
796 use safetensors::tensor::Dtype;
797 match view.dtype() {
798 Dtype::F32 => {
799 let raw = view.data();
800 if raw.len() % 4 != 0 {
801 bail!("F32 tensor data length not multiple of 4");
802 }
803 let mut out = Vec::with_capacity(raw.len() / 4);
804 for ch in raw.chunks_exact(4) {
805 out.push(f32::from_le_bytes([ch[0], ch[1], ch[2], ch[3]]));
806 }
807 Ok(out)
808 }
809 Dtype::F16 => {
810 let raw = view.data();
811 if raw.len() % 2 != 0 {
812 bail!("F16 tensor data length not multiple of 2");
813 }
814 let mut out = Vec::with_capacity(raw.len() / 2);
815 for ch in raw.chunks_exact(2) {
816 let bits = u16::from_le_bytes([ch[0], ch[1]]);
817 out.push(f16_to_f32(bits));
818 }
819 Ok(out)
820 }
821 Dtype::BF16 => {
822 let raw = view.data();
823 if raw.len() % 2 != 0 {
824 bail!("BF16 tensor data length not multiple of 2");
825 }
826 let mut out = Vec::with_capacity(raw.len() / 2);
827 for ch in raw.chunks_exact(2) {
828 let bits = u32::from(u16::from_le_bytes([ch[0], ch[1]])) << 16;
830 out.push(f32::from_bits(bits));
831 }
832 Ok(out)
833 }
834 other => bail!("unsupported tensor dtype for multimodal weight: {other:?}"),
835 }
836}
837
/// Widens an IEEE 754 half-precision bit pattern to `f32`.
///
/// Handles signed zero, subnormals, normals, infinities and NaN (the NaN
/// payload is shifted into the upper mantissa bits).
fn f16_to_f32(bits: u16) -> f32 {
    let sign_bit = u32::from(bits >> 15) << 31;
    let exponent = u32::from((bits >> 10) & 0x1f);
    let mantissa = u32::from(bits & 0x3ff);

    let magnitude = match (exponent, mantissa) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal: shift the mantissa up until its leading one reaches bit
        // 10 (the implicit-one position), lowering the exponent to match.
        (0, _) => {
            let shift = mantissa.leading_zeros() - 21;
            let fraction = (mantissa << shift) & 0x3ff;
            let biased = 127 - 14 - shift;
            (biased << 23) | (fraction << 13)
        }
        // Infinity or NaN.
        (0x1f, _) => (0xff << 23) | (mantissa << 13),
        // Normal number: rebias the exponent from 15 to 127.
        _ => ((exponent + 127 - 15) << 23) | (mantissa << 13),
    };
    f32::from_bits(sign_bit | magnitude)
}
863
/// Unit tests: CPU smoke tests for both projectors, safetensors decoding,
/// prompt tokenization and embedding fusion.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multimodal::{GemmaAudioConfig, GemmaMultimodalConfig, GemmaVisionConfig};

    /// Small configuration with both a vision and an audio section plus the
    /// special token ids the tokenization / fusion tests assert on.
    fn tiny_cfg() -> GemmaMultimodalConfig {
        GemmaMultimodalConfig {
            vision: Some(GemmaVisionConfig {
                patch_size: 2,
                model_patch_size: 4,
                mm_embed_dim: 8,
                mm_posemb_size: 16,
                num_soft_tokens: 4,
                output_proj_dims: 8,
                pooling_kernel_size: 1,
                rms_norm_eps: 1e-6,
            }),
            audio: Some(GemmaAudioConfig {
                hidden_size: 4,
                audio_embed_dim: 4,
                audio_samples_per_token: 8,
                output_proj_dims: 4,
                rms_norm_eps: 1e-6,
            }),
            image_token_id: Some(100),
            audio_token_id: Some(200),
            boi_token_id: Some(80),
            eoi_token_id: Some(81),
            boa_token_id: Some(90),
            eoa_token_index: Some(91),
            ..Default::default()
        }
    }

    /// Builds legacy-layout (`vision_tower.*` / `audio_tower.*`) weights filled
    /// with constants, sized for the same dimensions as `tiny_cfg`.
    /// `num_frames` does not influence any tensor size.
    fn synthetic_weights(num_patches: usize, num_frames: usize) -> MultimodalWeights {
        // Same values as the sub-configs in `tiny_cfg`.
        let v = GemmaVisionConfig {
            patch_size: 2,
            model_patch_size: 4,
            mm_embed_dim: 8,
            mm_posemb_size: 16,
            num_soft_tokens: 4,
            output_proj_dims: 8,
            pooling_kernel_size: 1,
            rms_norm_eps: 1e-6,
        };
        let a = GemmaAudioConfig {
            hidden_size: 4,
            audio_embed_dim: 4,
            audio_samples_per_token: 8,
            output_proj_dims: 4,
            rms_norm_eps: 1e-6,
        };
        let patch_features = v.patch_size * v.patch_size * 3;
        let mut w = MultimodalWeights::empty();
        w.insert(
            "vision_tower.embed.weight",
            vec![0.01f32; patch_features * v.mm_embed_dim],
        );
        w.insert(
            "vision_tower.pos_embed.weight",
            vec![0.0f32; num_patches * v.mm_embed_dim],
        );
        w.insert("vision_tower.norm.weight", vec![0.0f32; v.mm_embed_dim]);
        w.insert(
            "vision_tower.soft_token.weight",
            vec![0.01f32; num_patches * v.num_soft_tokens],
        );
        w.insert(
            "vision_tower.lm_proj.weight",
            vec![0.01f32; v.mm_embed_dim * v.output_proj_dims],
        );
        w.insert(
            "audio_tower.embed.weight",
            vec![0.01f32; a.audio_samples_per_token * a.audio_embed_dim],
        );
        w.insert("audio_tower.norm.weight", vec![0.0f32; a.audio_embed_dim]);
        w.insert(
            // The factor 8 matches the `lm_hidden` the tests pass to the runner.
            "audio_tower.lm_proj.weight",
            vec![0.01f32; a.audio_embed_dim * 8],
        );
        // Silences the unused-parameter warning.
        let _ = num_frames;
        w
    }

    /// Compiles the legacy vision projector on CPU and checks the output length.
    #[test]
    fn runner_compiles_and_runs_vision_projector_on_cpu() {
        let cfg = tiny_cfg();
        let num_patches = 4;
        let mut runner = GemmaMultimodalRunner::new(cfg, 8, Device::Cpu, None, None).unwrap();
        let weights = synthetic_weights(num_patches, 0);
        runner.compile_vision(num_patches).unwrap();
        runner.load_vision_weights(&weights).unwrap();
        // 12 = patch_size * patch_size * 3 with patch_size = 2.
        let patches = vec![0.5f32; num_patches * 12];
        let out = runner.project_image_patches(&patches).unwrap();
        assert_eq!(out.len(), 4 * 8);
    }

    /// Compiles the legacy audio projector on CPU and checks the output length.
    #[test]
    fn runner_compiles_and_runs_audio_projector_on_cpu() {
        let cfg = tiny_cfg();
        let num_frames = 2;
        let mut runner = GemmaMultimodalRunner::new(cfg, 8, Device::Cpu, None, None).unwrap();
        let weights = synthetic_weights(0, num_frames);
        runner.compile_audio(num_frames).unwrap();
        runner.load_audio_weights(&weights).unwrap();
        // 8 = audio_samples_per_token, so this is `num_frames` frames of samples.
        let frames = vec![0.1f32; num_frames * 8];
        let out = runner.project_audio_samples(&frames, &weights).unwrap();
        assert_eq!(out.len(), 2 * 8);
    }

    /// Encodes `f` as little-endian bf16 by keeping the upper 16 bits
    /// (truncation, no rounding).
    fn f32_to_bf16_bytes(f: f32) -> [u8; 2] {
        let bits = f.to_bits();
        let bf16 = ((bits >> 16) & 0xffff) as u16;
        bf16.to_le_bytes()
    }

    /// A missing mmproj file must surface as an error, not a panic.
    #[test]
    fn mmproj_gguf_loader_rejects_missing_file() {
        match MultimodalWeights::from_mmproj_gguf("/nonexistent/mmproj.gguf") {
            Ok(_) => panic!("expected error for missing mmproj path"),
            Err(err) => {
                let msg = format!("{err:#}");
                assert!(
                    msg.contains("mmproj") || msg.contains("open") || msg.contains("No such file"),
                    "unexpected error message: {msg}"
                );
            }
        }
    }

    /// Writes a hand-built safetensors file holding one BF16 tensor and checks
    /// that it round-trips through `from_safetensors`.
    #[test]
    fn weights_loader_decodes_bf16_safetensors() {
        // All four values are exactly representable in bf16.
        let originals: Vec<f32> = vec![0.0, 1.0, -1.0, 0.5];
        let mut payload: Vec<u8> = Vec::with_capacity(originals.len() * 2);
        for &f in &originals {
            payload.extend_from_slice(&f32_to_bf16_bytes(f));
        }
        let header = format!(
            r#"{{"audio_tower.norm.weight":{{"dtype":"BF16","shape":[4],"data_offsets":[0,{}]}}}}"#,
            payload.len(),
        );
        let header_bytes = header.as_bytes();
        // File layout: u64 little-endian header length, JSON header, raw payload.
        let mut buf = Vec::new();
        buf.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
        buf.extend_from_slice(header_bytes);
        buf.extend_from_slice(&payload);
        let tmp = std::env::temp_dir().join("rlx_gemma_mm_bf16_test.safetensors");
        std::fs::write(&tmp, &buf).unwrap();
        let w = MultimodalWeights::from_safetensors(&tmp).unwrap();
        std::fs::remove_file(&tmp).ok();
        let decoded = w.get("audio_tower.norm.weight").unwrap();
        assert_eq!(decoded.len(), originals.len());
        for (got, want) in decoded.iter().zip(originals.iter()) {
            assert!(
                (got - want).abs() < 1e-6,
                "bf16 decode: got {got} want {want}"
            );
        }
    }

    /// Writes a hand-built safetensors file holding one F32 tensor and checks
    /// that it is decoded verbatim.
    #[test]
    fn weights_loader_decodes_f32_safetensors() {
        let payload: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let header = format!(
            r#"{{"vision_tower.norm.weight":{{"dtype":"F32","shape":[4],"data_offsets":[0,{}]}}}}"#,
            payload.len(),
        );
        let header_bytes = header.as_bytes();
        // File layout: u64 little-endian header length, JSON header, raw payload.
        let mut buf = Vec::new();
        buf.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
        buf.extend_from_slice(header_bytes);
        buf.extend_from_slice(&payload);
        let tmp = std::env::temp_dir().join("rlx_gemma_multimodal_test.safetensors");
        std::fs::write(&tmp, &buf).unwrap();
        let w = MultimodalWeights::from_safetensors(&tmp).unwrap();
        std::fs::remove_file(&tmp).ok();
        assert_eq!(
            w.get("vision_tower.norm.weight").unwrap(),
            &[1.0, 2.0, 3.0, 4.0]
        );
    }

    /// Markers are replaced by begin token, repeated media token, end token.
    #[test]
    fn tokenize_prompt_derives_slot_counts() {
        let cfg = tiny_cfg();
        let runner = GemmaMultimodalRunner::new(cfg, 8, Device::Cpu, None, None).unwrap();
        // Byte-level stand-in for a tokenizer: one token id per byte.
        let encode = |s: &str| -> Result<Vec<u32>> { Ok(s.bytes().map(|b| b as u32).collect()) };
        let out = runner
            .tokenize_prompt("hi <image> see <audio>", &[4], &[16], &[], encode)
            .unwrap();
        let mut expected: Vec<u32> = b"hi ".iter().map(|b| *b as u32).collect();
        // Image slot: 4 soft tokens between boi (80) and eoi (81).
        expected.extend([80, 100, 100, 100, 100, 81]);
        expected.extend(b" see ".iter().map(|b| *b as u32));
        // Audio slot: 16 samples at 8 samples per token -> 2 tokens.
        expected.extend([90, 200, 200, 91]);
        assert_eq!(out, expected);
    }

    /// Rows at image / audio token ids are overwritten with the media embeddings.
    #[test]
    fn fuse_text_and_media_replaces_in_order() {
        let cfg = tiny_cfg();
        let runner = GemmaMultimodalRunner::new(cfg, 4, Device::Cpu, None, None).unwrap();
        // Four tokens with hidden size 4.
        let mut text = vec![0.0f32; 4 * 4];
        let ids = [10, 100, 200, 11];
        let img = vec![7.0f32; 4];
        let aud = vec![9.0f32; 4];
        runner
            .fuse_text_and_media(&mut text, &ids, &img, &aud, &[])
            .unwrap();
        assert_eq!(&text[4..8], &[7.0; 4]);
        assert_eq!(&text[8..12], &[9.0; 4]);
    }
}