1use anyhow::{Result, anyhow, bail};
17use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
18use rlx_ir::{DType, HirGraphExt, Shape};
19use serde::Deserialize;
20use std::path::Path;
21use std::str::FromStr;
22
23#[derive(Debug, Clone, Deserialize)]
30#[serde(default)]
31pub struct GemmaVisionConfig {
32 pub patch_size: usize,
34 pub model_patch_size: usize,
37 pub mm_embed_dim: usize,
39 pub mm_posemb_size: usize,
42 pub num_soft_tokens: usize,
44 pub output_proj_dims: usize,
47 pub pooling_kernel_size: usize,
50 pub rms_norm_eps: f64,
52}
53
54impl Default for GemmaVisionConfig {
55 fn default() -> Self {
56 Self {
57 patch_size: 16,
58 model_patch_size: 48,
59 mm_embed_dim: 3840,
60 mm_posemb_size: 1120,
61 num_soft_tokens: 280,
62 output_proj_dims: 3840,
63 pooling_kernel_size: 3,
64 rms_norm_eps: 1e-6,
65 }
66 }
67}
68
69#[derive(Debug, Clone, Deserialize)]
71#[serde(default)]
72pub struct GemmaAudioConfig {
73 pub hidden_size: usize,
75 pub audio_embed_dim: usize,
77 pub audio_samples_per_token: usize,
79 pub output_proj_dims: usize,
83 pub rms_norm_eps: f64,
84}
85
86impl Default for GemmaAudioConfig {
87 fn default() -> Self {
88 Self {
89 hidden_size: 640,
90 audio_embed_dim: 640,
91 audio_samples_per_token: 640,
92 output_proj_dims: 640,
93 rms_norm_eps: 1e-6,
94 }
95 }
96}
97
98#[derive(Debug, Clone, Deserialize, Default)]
101pub struct GemmaMultimodalConfig {
102 #[serde(default)]
103 pub vision: Option<GemmaVisionConfig>,
104 #[serde(default)]
105 pub audio: Option<GemmaAudioConfig>,
106 #[serde(default)]
107 pub image_token_id: Option<u32>,
108 #[serde(default)]
109 pub audio_token_id: Option<u32>,
110 #[serde(default)]
111 pub video_token_id: Option<u32>,
112 #[serde(default)]
113 pub boi_token_id: Option<u32>,
114 #[serde(default)]
115 pub eoi_token_id: Option<u32>,
116 #[serde(default)]
117 pub boa_token_id: Option<u32>,
118 #[serde(default)]
119 pub eoa_token_index: Option<u32>,
120}
121
122impl GemmaMultimodalConfig {
123 pub fn from_file(path: &Path) -> Result<Self> {
128 let data = std::fs::read_to_string(path)?;
129 Self::parse_json(&data)
130 }
131
132 pub fn parse_json(raw: &str) -> Result<Self> {
134 raw.parse()
135 }
136
137 pub fn has_vision(&self) -> bool {
138 self.vision.is_some()
139 }
140 pub fn has_audio(&self) -> bool {
141 self.audio.is_some()
142 }
143}
144
145impl FromStr for GemmaMultimodalConfig {
146 type Err = anyhow::Error;
147
148 fn from_str(raw: &str) -> Result<Self, Self::Err> {
149 let value: serde_json::Value = serde_json::from_str(raw)?;
150 let vision = value
151 .get("vision_config")
152 .filter(|v| v.is_object())
153 .map(|v| serde_json::from_value::<GemmaVisionConfig>(v.clone()))
154 .transpose()?;
155 let audio = value
156 .get("audio_config")
157 .filter(|v| v.is_object())
158 .map(|v| serde_json::from_value::<GemmaAudioConfig>(v.clone()))
159 .transpose()?;
160 let pick_u32 = |k: &str| value.get(k).and_then(|v| v.as_u64()).map(|x| x as u32);
161 Ok(Self {
162 vision,
163 audio,
164 image_token_id: pick_u32("image_token_id"),
165 audio_token_id: pick_u32("audio_token_id"),
166 video_token_id: pick_u32("video_token_id"),
167 boi_token_id: pick_u32("boi_token_id"),
168 eoi_token_id: pick_u32("eoi_token_id"),
169 boa_token_id: pick_u32("boa_token_id"),
170 eoa_token_index: pick_u32("eoa_token_index"),
171 })
172 }
173}
174
175pub fn build_vision_projection_hir(
200 hir: &mut HirModule,
201 inputs: VisionProjectionInputs,
202 cfg: &GemmaVisionConfig,
203) -> Result<HirNodeId> {
204 let mm_embed_dim = cfg.mm_embed_dim;
215 let normed = {
216 let mut gb = HirMut::new(hir);
217 let projected = gb.mm(inputs.patches, inputs.embed_w);
218 let with_pos = gb.add(projected, inputs.pos_embed);
219 let gamma = gb.add(inputs.ones, inputs.norm_w);
220 gb.rms_norm(with_pos, gamma, inputs.zero_beta, cfg.rms_norm_eps as f32)
221 };
223 let mut gb = HirMut::new(hir);
224 let normed_t = gb.transpose_(normed, vec![0, 2, 1]);
226 let soft = gb.mm(normed_t, inputs.soft_token_w);
228 let soft_t = gb.transpose_(soft, vec![0, 2, 1]);
230 let out = gb.mm(soft_t, inputs.lm_proj_w);
233 let _ = mm_embed_dim;
234 Ok(out)
235}
236
237#[derive(Debug, Clone, Copy)]
242pub struct VisionProjectionInputs {
243 pub patches: HirNodeId,
244 pub embed_w: HirNodeId,
245 pub pos_embed: HirNodeId,
246 pub norm_w: HirNodeId,
247 pub ones: HirNodeId,
248 pub zero_beta: HirNodeId,
249 pub soft_token_w: HirNodeId,
250 pub lm_proj_w: HirNodeId,
251}
252
253#[derive(Debug, Clone, Copy)]
260pub struct VisionProjectionLearnedQueriesInputs {
261 pub patches: HirNodeId,
263 pub embed_w: HirNodeId,
265 pub pos_embed: HirNodeId,
267 pub norm_w: HirNodeId,
269 pub ones: HirNodeId,
271 pub zero_beta: HirNodeId,
273 pub queries: HirNodeId,
275 pub k_proj: HirNodeId,
277 pub v_proj: HirNodeId,
279 pub lm_proj_w: HirNodeId,
282}
283
284pub fn build_vision_projection_learned_queries_hir(
302 hir: &mut HirModule,
303 inputs: VisionProjectionLearnedQueriesInputs,
304 cfg: &GemmaVisionConfig,
305) -> Result<HirNodeId> {
306 let normed = {
308 let mut gb = HirMut::new(hir);
309 let projected = gb.mm(inputs.patches, inputs.embed_w);
310 let with_pos = gb.add(projected, inputs.pos_embed);
311 let gamma = gb.add(inputs.ones, inputs.norm_w);
312 gb.rms_norm(with_pos, gamma, inputs.zero_beta, cfg.rms_norm_eps as f32)
313 };
314 let mut gb = HirMut::new(hir);
315 let k = gb.mm(normed, inputs.k_proj);
317 let v = gb.mm(normed, inputs.v_proj);
318 use rlx_ir::Op;
328 let q_shape = gb.shape(inputs.queries).clone();
329 let k_shape = gb.shape(k).clone();
330 let b = q_shape.dim(0).unwrap_static();
332 let _ = b;
333 let attn_shape = q_shape.clone();
337 let attn = gb.0.mir(
338 Op::Attention {
339 num_heads: 1,
340 head_dim: cfg.mm_embed_dim,
341 mask_kind: rlx_ir::op::MaskKind::None,
342 score_scale: None,
343 attn_logit_softcap: None,
344 },
345 vec![inputs.queries, k, v],
346 attn_shape,
347 );
348 let _ = k_shape;
349 let out = gb.mm(attn, inputs.lm_proj_w);
351 Ok(out)
352}
353
354pub fn build_vision_projection_learned_queries_graph(
357 batch: usize,
358 num_patches: usize,
359 cfg: &GemmaVisionConfig,
360) -> Result<ProjectionGraph> {
361 let mut hir = HirModule::new("gemma_vision_projector_lq");
362 let patch_features = cfg.patch_size * cfg.patch_size * 3;
363 let patches = hir.input(
364 "patches",
365 Shape::new(&[batch, num_patches, patch_features], DType::F32),
366 );
367 let embed_w = hir.param(
368 "vision_tower.embed.weight",
369 Shape::new(&[patch_features, cfg.mm_embed_dim], DType::F32),
370 );
371 let pos_embed = hir.param(
372 "vision_tower.pos_embed.weight",
373 Shape::new(&[num_patches, cfg.mm_embed_dim], DType::F32),
374 );
375 let norm_w = hir.param(
376 "vision_tower.norm.weight",
377 Shape::new(&[cfg.mm_embed_dim], DType::F32),
378 );
379 let ones = hir.param(
380 "vision_tower.ones",
381 Shape::new(&[cfg.mm_embed_dim], DType::F32),
382 );
383 let zero_beta = hir.param(
384 "vision_tower.zero_beta",
385 Shape::new(&[cfg.mm_embed_dim], DType::F32),
386 );
387 let queries = hir.param(
388 "vision_tower.queries.weight",
389 Shape::new(&[batch, cfg.num_soft_tokens, cfg.mm_embed_dim], DType::F32),
390 );
391 let k_proj = hir.param(
392 "vision_tower.k_proj.weight",
393 Shape::new(&[cfg.mm_embed_dim, cfg.mm_embed_dim], DType::F32),
394 );
395 let v_proj = hir.param(
396 "vision_tower.v_proj.weight",
397 Shape::new(&[cfg.mm_embed_dim, cfg.mm_embed_dim], DType::F32),
398 );
399 let lm_proj_w = hir.param(
400 "vision_tower.lm_proj.weight",
401 Shape::new(&[cfg.mm_embed_dim, cfg.output_proj_dims], DType::F32),
402 );
403 let inputs = VisionProjectionLearnedQueriesInputs {
404 patches,
405 embed_w,
406 pos_embed,
407 norm_w,
408 ones,
409 zero_beta,
410 queries,
411 k_proj,
412 v_proj,
413 lm_proj_w,
414 };
415 let output = build_vision_projection_learned_queries_hir(&mut hir, inputs, cfg)?;
416 hir.set_outputs(vec![output]);
417 Ok(ProjectionGraph {
418 hir,
419 output,
420 input_keys: vec!["patches".into()],
421 })
422}
423
424pub fn build_audio_projection_hir(
434 hir: &mut HirModule,
435 inputs: AudioProjectionInputs,
436 cfg: &GemmaAudioConfig,
437) -> Result<HirNodeId> {
438 let mut gb = HirMut::new(hir);
439 let projected = gb.mm(inputs.frames, inputs.embed_w);
440 let gamma = gb.add(inputs.ones, inputs.norm_w);
441 let normed = gb.rms_norm(projected, gamma, inputs.zero_beta, cfg.rms_norm_eps as f32);
442 let out = gb.mm(normed, inputs.lm_proj_w);
443 Ok(out)
444}
445
446#[derive(Debug, Clone, Copy)]
448pub struct AudioProjectionInputs {
449 pub frames: HirNodeId,
450 pub embed_w: HirNodeId,
451 pub norm_w: HirNodeId,
452 pub ones: HirNodeId,
453 pub zero_beta: HirNodeId,
454 pub lm_proj_w: HirNodeId,
455}
456
457#[derive(Debug)]
463pub struct ProjectionGraph {
464 pub hir: HirModule,
465 pub output: HirNodeId,
467 pub input_keys: Vec<String>,
470}
471
472pub fn build_vision_projection_graph(
478 batch: usize,
479 num_patches: usize,
480 cfg: &GemmaVisionConfig,
481) -> Result<ProjectionGraph> {
482 let mut hir = HirModule::new("gemma_vision_projector");
483 let patch_features = cfg.patch_size * cfg.patch_size * 3;
484 let patches = hir.input(
485 "patches",
486 Shape::new(&[batch, num_patches, patch_features], DType::F32),
487 );
488 let embed_w = hir.param(
489 "vision_tower.embed.weight",
490 Shape::new(&[patch_features, cfg.mm_embed_dim], DType::F32),
491 );
492 let pos_embed = hir.param(
493 "vision_tower.pos_embed.weight",
494 Shape::new(&[num_patches, cfg.mm_embed_dim], DType::F32),
495 );
496 let norm_w = hir.param(
497 "vision_tower.norm.weight",
498 Shape::new(&[cfg.mm_embed_dim], DType::F32),
499 );
500 let ones = hir.param(
501 "vision_tower.ones",
502 Shape::new(&[cfg.mm_embed_dim], DType::F32),
503 );
504 let zero_beta = hir.param(
505 "vision_tower.zero_beta",
506 Shape::new(&[cfg.mm_embed_dim], DType::F32),
507 );
508 let soft_token_w = hir.param(
509 "vision_tower.soft_token.weight",
510 Shape::new(&[num_patches, cfg.num_soft_tokens], DType::F32),
513 );
514 let lm_proj_w = hir.param(
515 "vision_tower.lm_proj.weight",
516 Shape::new(&[cfg.mm_embed_dim, cfg.output_proj_dims], DType::F32),
517 );
518 let inputs = VisionProjectionInputs {
519 patches,
520 embed_w,
521 pos_embed,
522 norm_w,
523 ones,
524 zero_beta,
525 soft_token_w,
526 lm_proj_w,
527 };
528 let output = build_vision_projection_hir(&mut hir, inputs, cfg)?;
529 hir.set_outputs(vec![output]);
530 Ok(ProjectionGraph {
531 hir,
532 output,
533 input_keys: vec!["patches".into()],
534 })
535}
536
537pub fn build_audio_projection_graph(
540 batch: usize,
541 num_frames: usize,
542 cfg: &GemmaAudioConfig,
543 lm_hidden: usize,
544) -> Result<ProjectionGraph> {
545 let mut hir = HirModule::new("gemma_audio_projector");
546 let frames = hir.input(
547 "frames",
548 Shape::new(
549 &[batch, num_frames, cfg.audio_samples_per_token],
550 DType::F32,
551 ),
552 );
553 let embed_w = hir.param(
554 "audio_tower.embed.weight",
555 Shape::new(
556 &[cfg.audio_samples_per_token, cfg.audio_embed_dim],
557 DType::F32,
558 ),
559 );
560 let norm_w = hir.param(
561 "audio_tower.norm.weight",
562 Shape::new(&[cfg.audio_embed_dim], DType::F32),
563 );
564 let ones = hir.param(
565 "audio_tower.ones",
566 Shape::new(&[cfg.audio_embed_dim], DType::F32),
567 );
568 let zero_beta = hir.param(
569 "audio_tower.zero_beta",
570 Shape::new(&[cfg.audio_embed_dim], DType::F32),
571 );
572 let lm_proj_w = hir.param(
573 "audio_tower.lm_proj.weight",
574 Shape::new(&[cfg.audio_embed_dim, lm_hidden], DType::F32),
575 );
576 let inputs = AudioProjectionInputs {
577 frames,
578 embed_w,
579 norm_w,
580 ones,
581 zero_beta,
582 lm_proj_w,
583 };
584 let output = build_audio_projection_hir(&mut hir, inputs, cfg)?;
585 hir.set_outputs(vec![output]);
586 Ok(ProjectionGraph {
587 hir,
588 output,
589 input_keys: vec!["frames".into()],
590 })
591}
592
/// Per-channel mean and standard deviation used to normalize RGB values.
///
/// The patch extractor in this file maps a byte `v` of channel `c` to
/// `(v / 255 - mean[c]) / std[c]`.
#[derive(Clone, Copy, Debug)]
pub struct ImageNormalize {
    /// Per-channel mean, in R, G, B order.
    pub mean: [f32; 3],
    /// Per-channel standard deviation, in R, G, B order.
    pub std: [f32; 3],
}

impl ImageNormalize {
    /// Zero mean and unit deviation: leaves the `v / 255` scaling untouched.
    pub const fn unit() -> Self {
        let mean = [0.0, 0.0, 0.0];
        let std = [1.0, 1.0, 1.0];
        Self { mean, std }
    }

    /// The preset this file names "imagenet".
    pub const fn imagenet() -> Self {
        let mean = [0.485, 0.456, 0.406];
        let std = [0.229, 0.224, 0.225];
        Self { mean, std }
    }

    /// The preset this file names "clip".
    pub const fn clip() -> Self {
        let mean = [0.48145466, 0.4578275, 0.40821073];
        let std = [0.26862954, 0.261_302_6, 0.275_777_1];
        Self { mean, std }
    }
}

impl Default for ImageNormalize {
    /// The default is the [`ImageNormalize::imagenet`] preset.
    fn default() -> Self {
        ImageNormalize::imagenet()
    }
}
638
639pub fn extract_image_patches(
649 rgb: &[u8],
650 width: usize,
651 height: usize,
652 patch_size: usize,
653) -> Result<Vec<f32>> {
654 extract_image_patches_normalized(rgb, width, height, patch_size, ImageNormalize::unit())
655}
656
657pub fn extract_image_patches_normalized(
660 rgb: &[u8],
661 width: usize,
662 height: usize,
663 patch_size: usize,
664 norm: ImageNormalize,
665) -> Result<Vec<f32>> {
666 if rgb.len() != width * height * 3 {
667 bail!(
668 "image buffer is {} bytes but {}x{}x3 = {}",
669 rgb.len(),
670 width,
671 height,
672 width * height * 3,
673 );
674 }
675 if patch_size == 0 {
676 bail!("patch_size must be > 0");
677 }
678 let patch_cols = width / patch_size;
679 let patch_rows = height / patch_size;
680 let num_patches = patch_rows * patch_cols;
681 let per_patch = patch_size * patch_size * 3;
682 let mut out = vec![0f32; num_patches * per_patch];
683 let row_stride_bytes = width * 3;
684 let inv = 1.0_f32 / 255.0;
687 let scale = [inv / norm.std[0], inv / norm.std[1], inv / norm.std[2]];
688 let bias = [
689 -norm.mean[0] / norm.std[0],
690 -norm.mean[1] / norm.std[1],
691 -norm.mean[2] / norm.std[2],
692 ];
693 for pr in 0..patch_rows {
694 let pr_base_y = pr * patch_size;
695 for pc in 0..patch_cols {
696 let patch_index = pr * patch_cols + pc;
697 let dst_base = patch_index * per_patch;
698 let pc_base_x = pc * patch_size;
699 for py in 0..patch_size {
700 let src_row_off = (pr_base_y + py) * row_stride_bytes + pc_base_x * 3;
701 let dst_row_off = dst_base + py * patch_size * 3;
702 let src = &rgb[src_row_off..src_row_off + patch_size * 3];
706 let dst = &mut out[dst_row_off..dst_row_off + patch_size * 3];
707 for px in 0..patch_size {
708 let s = px * 3;
709 dst[s] = src[s] as f32 * scale[0] + bias[0];
710 dst[s + 1] = src[s + 1] as f32 * scale[1] + bias[1];
711 dst[s + 2] = src[s + 2] as f32 * scale[2] + bias[2];
712 }
713 }
714 }
715 }
716 Ok(out)
717}
718
719pub fn frame_audio_samples(samples: &[f32], samples_per_token: usize) -> Result<(Vec<f32>, usize)> {
724 if samples_per_token == 0 {
725 bail!("samples_per_token must be > 0");
726 }
727 let num_frames = samples.len().div_ceil(samples_per_token).max(1);
728 let mut out = vec![0f32; num_frames * samples_per_token];
729 let copy_len = samples.len().min(out.len());
730 out[..copy_len].copy_from_slice(&samples[..copy_len]);
731 Ok((out, num_frames))
732}
733
734pub fn load_image_patches(
745 path: impl AsRef<std::path::Path>,
746 patch_size: usize,
747 max_side_patches: usize,
748) -> Result<(Vec<f32>, usize, usize)> {
749 load_image_patches_normalized(
750 path,
751 patch_size,
752 max_side_patches,
753 ImageNormalize::imagenet(),
754 )
755}
756
757pub fn load_image_patches_normalized(
762 path: impl AsRef<std::path::Path>,
763 patch_size: usize,
764 max_side_patches: usize,
765 norm: ImageNormalize,
766) -> Result<(Vec<f32>, usize, usize)> {
767 let img = image::open(path.as_ref()).map_err(|e| anyhow!("decode {:?}: {e}", path.as_ref()))?;
768 let rgb = img.to_rgb8();
769 let (w, h) = rgb.dimensions();
770 let (w, h) = (w as usize, h as usize);
771 let cap_px = max_side_patches.max(1) * patch_size;
772 let target_w = (w.min(cap_px) / patch_size).max(1) * patch_size;
773 let target_h = (h.min(cap_px) / patch_size).max(1) * patch_size;
774 let resized = if (target_w, target_h) != (w, h) {
775 image::DynamicImage::ImageRgb8(rgb)
776 .resize_exact(
777 target_w as u32,
778 target_h as u32,
779 image::imageops::FilterType::Triangle,
780 )
781 .to_rgb8()
782 } else {
783 rgb
784 };
785 let patches = extract_image_patches_normalized(
786 resized.as_raw(),
787 resized.width() as usize,
788 resized.height() as usize,
789 patch_size,
790 norm,
791 )?;
792 Ok((patches, target_h / patch_size, target_w / patch_size))
793}
794
/// Sample rate, in hertz, that decoded WAV audio is resampled to.
const SAMPLE_RATE_GEMMA4_HZ: u32 = 16_000;
798
799pub fn load_wav_mono_16khz(path: impl AsRef<std::path::Path>) -> Result<Vec<f32>> {
803 let bytes =
804 std::fs::read(path.as_ref()).map_err(|e| anyhow!("read {:?}: {e}", path.as_ref()))?;
805 parse_wav_16khz_mono(&bytes)
806}
807
808pub fn parse_wav_16khz_mono(bytes: &[u8]) -> Result<Vec<f32>> {
810 let (channels, src_rate, samples) = parse_pcm16_wav(bytes)?;
811 let mono = if channels == 1 {
813 samples
814 } else {
815 let n = samples.len() / channels as usize;
816 let mut out = Vec::with_capacity(n);
817 for frame in 0..n {
818 let base = frame * channels as usize;
819 let mut sum = 0.0f32;
820 for c in 0..channels as usize {
821 sum += samples[base + c];
822 }
823 out.push(sum / channels as f32);
824 }
825 out
826 };
827 if src_rate == SAMPLE_RATE_GEMMA4_HZ {
828 Ok(mono)
829 } else {
830 Ok(resample_linear(&mono, src_rate, SAMPLE_RATE_GEMMA4_HZ))
831 }
832}
833
/// Resamples `samples` from `src_rate` to `dst_rate` by linear interpolation.
///
/// The output has `round(samples.len() * dst_rate / src_rate)` samples.
/// Output sample `i` is interpolated at source position
/// `i * src_rate / dst_rate`; positions past the last input sample reuse that
/// sample. No low-pass filtering is applied.
///
/// Special cases:
/// * equal rates or empty input: returns a copy of `samples`;
/// * otherwise, a zero rate on either side: returns an empty vector. A zero
///   `src_rate` used to produce an infinite ratio and panic on allocation.
pub fn resample_linear(samples: &[f32], src_rate: u32, dst_rate: u32) -> Vec<f32> {
    if src_rate == dst_rate || samples.is_empty() {
        return samples.to_vec();
    }
    if src_rate == 0 || dst_rate == 0 {
        return Vec::new();
    }

    let ratio = f64::from(dst_rate) / f64::from(src_rate);
    let out_len = ((samples.len() as f64) * ratio).round() as usize;
    let step = f64::from(src_rate) / f64::from(dst_rate);
    let last = samples.len() - 1;

    (0..out_len)
        .map(|i| {
            let pos = i as f64 * step;
            // Clamp so floating-point rounding can never index past the end.
            let lo = (pos.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = (pos - lo as f64) as f32;
            let (a, b) = (samples[lo], samples[hi]);
            a + (b - a) * frac
        })
        .collect()
}
859
860fn parse_pcm16_wav(bytes: &[u8]) -> Result<(u16, u32, Vec<f32>)> {
862 if bytes.len() < 44 || &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
863 bail!("not a RIFF/WAVE file");
864 }
865 let mut pos = 12usize;
866 let mut fmt: Option<(u16, u16, u32, u16)> = None;
867 let mut data_chunk: Option<&[u8]> = None;
868 while pos + 8 <= bytes.len() {
869 let chunk_id = &bytes[pos..pos + 4];
870 let chunk_size = u32::from_le_bytes([
871 bytes[pos + 4],
872 bytes[pos + 5],
873 bytes[pos + 6],
874 bytes[pos + 7],
875 ]) as usize;
876 pos += 8;
877 let chunk = &bytes[pos..pos + chunk_size.min(bytes.len() - pos)];
878 match chunk_id {
879 b"fmt " => {
880 if chunk.len() < 16 {
881 bail!("wav fmt chunk too small");
882 }
883 let audio_format = u16::from_le_bytes([chunk[0], chunk[1]]);
884 let channels = u16::from_le_bytes([chunk[2], chunk[3]]);
885 let sr = u32::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
886 let bps = u16::from_le_bytes([chunk[14], chunk[15]]);
887 fmt = Some((audio_format, channels, sr, bps));
888 }
889 b"data" => data_chunk = Some(chunk),
890 _ => {}
891 }
892 pos += chunk_size;
893 if chunk_size % 2 == 1 {
894 pos += 1; }
896 }
897 let (audio_format, channels, sr, bps) = fmt.ok_or_else(|| anyhow!("wav missing fmt chunk"))?;
898 if audio_format != 1 {
899 bail!("wav: only PCM supported (format={audio_format})");
900 }
901 if bps != 16 {
902 bail!("wav: only 16-bit PCM supported, got {bps}-bit");
903 }
904 let data = data_chunk.ok_or_else(|| anyhow!("wav missing data chunk"))?;
905 if data.len() % 2 != 0 {
906 bail!("wav data chunk not aligned to 2-byte sample width");
907 }
908 const SCALE: f32 = 1.0_f32 / 32_768.0;
911 let n = data.len() / 2;
912 let mut samples = Vec::with_capacity(n);
913 let mut i = 0;
919 while i + 8 <= n {
920 let base = i * 2;
922 for k in 0..8 {
923 let lo = data[base + k * 2];
924 let hi = data[base + k * 2 + 1];
925 samples.push(i16::from_le_bytes([lo, hi]) as f32 * SCALE);
926 }
927 i += 8;
928 }
929 while i < n {
930 let base = i * 2;
931 samples.push(i16::from_le_bytes([data[base], data[base + 1]]) as f32 * SCALE);
932 i += 1;
933 }
934 Ok((channels, sr, samples))
935}
936
/// One media attachment in a prompt, in marker order.
///
/// `count` is the number of placeholder tokens emitted for the slot; begin /
/// end bracket tokens are added on top of that when configured.
#[derive(Clone, Copy, Debug)]
pub enum MediaSlot {
    /// Expands to `count` image placeholder tokens.
    Image { count: usize },
    /// Expands to `count` audio placeholder tokens.
    Audio { count: usize },
    /// Expands to `count` video placeholder tokens.
    Video { count: usize },
}
950
/// Image marker in the `<|...|>` spelling (the `_HF` suffix presumably refers
/// to Hugging Face chat templates -- confirm).
pub const IMAGE_MARKER_HF: &str = "<|image|>";
/// Audio marker in the `<|...|>` spelling.
pub const AUDIO_MARKER_HF: &str = "<|audio|>";
/// Video marker in the `<|...|>` spelling.
pub const VIDEO_MARKER_HF: &str = "<|video|>";

/// Image marker in the short spelling.
pub const IMAGE_MARKER: &str = "<image>";
/// Audio marker in the short spelling.
pub const AUDIO_MARKER: &str = "<audio>";
/// Video marker.
///
/// NOTE(review): unlike the image and audio short forms this is identical to
/// [`VIDEO_MARKER_HF`], so `<video>` is not recognized -- confirm whether
/// that is intended.
pub const VIDEO_MARKER: &str = "<|video|>";

/// Modality a marker string stands for. The scanner below records it in its
/// table but does not currently act on it.
#[derive(Clone, Copy)]
enum MediaMarkerKind {
    Image,
    Audio,
    Video,
}

/// Finds the media marker that occurs earliest in `prompt`.
///
/// Returns the marker's byte offset and the marker text, or `None` when no
/// marker is present. When two table entries match at the same offset the
/// one listed first wins.
fn next_media_marker(prompt: &str) -> Option<(usize, &'static str)> {
    const MARKERS: [(&str, MediaMarkerKind); 6] = [
        (IMAGE_MARKER_HF, MediaMarkerKind::Image),
        (IMAGE_MARKER, MediaMarkerKind::Image),
        (AUDIO_MARKER_HF, MediaMarkerKind::Audio),
        (AUDIO_MARKER, MediaMarkerKind::Audio),
        (VIDEO_MARKER_HF, MediaMarkerKind::Video),
        (VIDEO_MARKER, MediaMarkerKind::Video),
    ];
    MARKERS
        .iter()
        .filter_map(|&(marker, _)| prompt.find(marker).map(|at| (at, marker)))
        // `min_by_key` keeps the first of several equal minima, matching the
        // strict `<` comparison this replaces.
        .min_by_key(|&(at, _)| at)
}
987
988pub fn tokenize_with_media<F>(
1006 prompt: &str,
1007 slots: &[MediaSlot],
1008 cfg: &GemmaMultimodalConfig,
1009 mut encode_fn: F,
1010) -> Result<Vec<u32>>
1011where
1012 F: FnMut(&str) -> Result<Vec<u32>>,
1013{
1014 let mut text_chunks: Vec<Vec<u32>> = Vec::with_capacity(slots.len() + 1);
1019 let mut cursor = 0usize;
1020 let mut markers_seen = 0usize;
1021 let bytes = prompt.as_bytes();
1022 while cursor <= bytes.len() {
1023 let remainder = &prompt[cursor..];
1024 let next = next_media_marker(remainder);
1025 match next {
1026 Some((rel, marker)) => {
1027 let chunk = &remainder[..rel];
1028 text_chunks.push(encode_fn(chunk)?);
1029 cursor += rel + marker.len();
1030 markers_seen += 1;
1031 }
1032 None => {
1033 text_chunks.push(encode_fn(remainder)?);
1035 break;
1036 }
1037 }
1038 }
1039 if markers_seen != slots.len() {
1040 bail!(
1041 "prompt has {markers_seen} media markers but {} slot(s) supplied",
1042 slots.len(),
1043 );
1044 }
1045 expand_media_placeholders(&text_chunks, slots, cfg)
1046}
1047
1048pub fn expand_media_placeholders(
1059 text_chunks: &[Vec<u32>],
1060 slots: &[MediaSlot],
1061 cfg: &GemmaMultimodalConfig,
1062) -> Result<Vec<u32>> {
1063 if text_chunks.len() != slots.len() + 1 {
1064 bail!(
1065 "text_chunks ({}) must equal slots ({}) + 1",
1066 text_chunks.len(),
1067 slots.len(),
1068 );
1069 }
1070 let mut out: Vec<u32> =
1071 Vec::with_capacity(text_chunks.iter().map(|c| c.len()).sum::<usize>() + slots.len() * 16);
1072 for (i, chunk) in text_chunks.iter().enumerate() {
1073 out.extend_from_slice(chunk);
1074 if i < slots.len() {
1075 match slots[i] {
1076 MediaSlot::Image { count } => {
1077 let token = cfg.image_token_id.ok_or_else(|| {
1078 anyhow!("image slot requested but image_token_id is unset")
1079 })?;
1080 if let Some(boi) = cfg.boi_token_id {
1081 out.push(boi);
1082 }
1083 for _ in 0..count {
1084 out.push(token);
1085 }
1086 if let Some(eoi) = cfg.eoi_token_id {
1087 out.push(eoi);
1088 }
1089 }
1090 MediaSlot::Audio { count } => {
1091 let token = cfg.audio_token_id.ok_or_else(|| {
1092 anyhow!("audio slot requested but audio_token_id is unset")
1093 })?;
1094 if let Some(boa) = cfg.boa_token_id {
1095 out.push(boa);
1096 }
1097 for _ in 0..count {
1098 out.push(token);
1099 }
1100 if let Some(eoa) = cfg.eoa_token_index {
1101 out.push(eoa);
1102 }
1103 }
1104 MediaSlot::Video { count } => {
1105 let token = cfg.video_token_id.ok_or_else(|| {
1106 anyhow!("video slot requested but video_token_id is unset")
1107 })?;
1108 if let Some(boi) = cfg.boi_token_id {
1109 out.push(boi);
1110 }
1111 for _ in 0..count {
1112 out.push(token);
1113 }
1114 if let Some(eoi) = cfg.eoi_token_id {
1115 out.push(eoi);
1116 }
1117 }
1118 }
1119 }
1120 }
1121 Ok(out)
1122}
1123
1124pub fn fuse_multimodal_embeddings(
1139 text_embeds: &mut [f32],
1140 token_ids: &[u32],
1141 hidden: usize,
1142 cfg: &GemmaMultimodalConfig,
1143 image_embeds: &[f32],
1144 audio_embeds: &[f32],
1145 video_embeds: &[f32],
1146) -> Result<()> {
1147 if text_embeds.len() != token_ids.len() * hidden {
1148 bail!(
1149 "text_embeds {} != tokens {} * hidden {}",
1150 text_embeds.len(),
1151 token_ids.len(),
1152 hidden,
1153 );
1154 }
1155 let mut img_cursor = 0usize;
1156 let mut aud_cursor = 0usize;
1157 let mut vid_cursor = 0usize;
1158 for (i, &tok) in token_ids.iter().enumerate() {
1159 let dst = &mut text_embeds[i * hidden..(i + 1) * hidden];
1160 if Some(tok) == cfg.image_token_id {
1161 let src = image_embeds
1162 .get(img_cursor * hidden..(img_cursor + 1) * hidden)
1163 .ok_or_else(|| {
1164 anyhow!(
1165 "image_embeds exhausted at token {i}: need {} rows, have {}",
1166 img_cursor + 1,
1167 image_embeds.len() / hidden,
1168 )
1169 })?;
1170 dst.copy_from_slice(src);
1171 img_cursor += 1;
1172 } else if Some(tok) == cfg.video_token_id {
1173 let src = video_embeds
1174 .get(vid_cursor * hidden..(vid_cursor + 1) * hidden)
1175 .ok_or_else(|| {
1176 anyhow!(
1177 "video_embeds exhausted at token {i}: need {} rows, have {}",
1178 vid_cursor + 1,
1179 video_embeds.len() / hidden,
1180 )
1181 })?;
1182 dst.copy_from_slice(src);
1183 vid_cursor += 1;
1184 } else if Some(tok) == cfg.audio_token_id {
1185 let src = audio_embeds
1186 .get(aud_cursor * hidden..(aud_cursor + 1) * hidden)
1187 .ok_or_else(|| {
1188 anyhow!(
1189 "audio_embeds exhausted at token {i}: need {} rows, have {}",
1190 aud_cursor + 1,
1191 audio_embeds.len() / hidden,
1192 )
1193 })?;
1194 dst.copy_from_slice(src);
1195 aud_cursor += 1;
1196 }
1197 }
1198 Ok(())
1199}
1200
#[cfg(test)]
mod tests {
    use super::*;

    /// Unified-layout config fixture: token ids plus both tower sub-configs.
    const GEMMA_4_12B_FULL_CONFIG: &str = r#"{
        "model_type": "gemma4_unified",
        "audio_token_id": 258881,
        "image_token_id": 258880,
        "video_token_id": 258884,
        "boi_token_id": 255999,
        "eoi_token_id": 258882,
        "boa_token_id": 256000,
        "eoa_token_index": 258883,
        "audio_config": {
            "audio_embed_dim": 640,
            "audio_samples_per_token": 640,
            "hidden_size": 640,
            "output_proj_dims": 640,
            "rms_norm_eps": 1e-6
        },
        "vision_config": {
            "mm_embed_dim": 3840,
            "mm_posemb_size": 1120,
            "model_patch_size": 48,
            "num_soft_tokens": 280,
            "output_proj_dims": 3840,
            "patch_size": 16,
            "pooling_kernel_size": 3,
            "rms_norm_eps": 1e-6
        }
    }"#;

    /// Byte-per-token encoder used by the tokenizer tests.
    fn byte_tokens(text: &[u8]) -> Vec<u32> {
        text.iter().map(|b| *b as u32).collect()
    }

    #[test]
    fn multimodal_config_parses_unified_layout() {
        let cfg = GemmaMultimodalConfig::parse_json(GEMMA_4_12B_FULL_CONFIG).unwrap();

        let vision = cfg.vision.as_ref().unwrap();
        assert_eq!(vision.patch_size, 16);
        assert_eq!(vision.model_patch_size, 48);
        assert_eq!(vision.mm_embed_dim, 3840);
        assert_eq!(vision.num_soft_tokens, 280);
        assert_eq!(vision.output_proj_dims, 3840);
        assert_eq!(vision.pooling_kernel_size, 3);

        let audio = cfg.audio.as_ref().unwrap();
        assert_eq!(audio.audio_samples_per_token, 640);
        assert_eq!(audio.audio_embed_dim, 640);
        assert_eq!(audio.output_proj_dims, 640);

        assert_eq!(cfg.image_token_id, Some(258_880));
        assert_eq!(cfg.audio_token_id, Some(258_881));
        assert_eq!(cfg.video_token_id, Some(258_884));
    }

    #[test]
    fn fuse_replaces_only_placeholder_rows() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(100),
            audio_token_id: Some(200),
            ..Default::default()
        };
        let hidden = 4;
        // One row per token: text, image placeholder, audio placeholder, text.
        let mut text = vec![
            1.0, 1.0, 1.0, 1.0, //
            0.0, 0.0, 0.0, 0.0, //
            0.0, 0.0, 0.0, 0.0, //
            2.0, 2.0, 2.0, 2.0, //
        ];
        let ids = [42, 100, 200, 43];
        let img = vec![7.0, 7.0, 7.0, 7.0];
        let aud = vec![9.0, 9.0, 9.0, 9.0];

        fuse_multimodal_embeddings(&mut text, &ids, hidden, &cfg, &img, &aud, &[]).unwrap();

        assert_eq!(&text[0..4], &[1.0, 1.0, 1.0, 1.0]);
        assert_eq!(&text[4..8], &[7.0, 7.0, 7.0, 7.0]);
        assert_eq!(&text[8..12], &[9.0, 9.0, 9.0, 9.0]);
        assert_eq!(&text[12..16], &[2.0, 2.0, 2.0, 2.0]);
    }

    #[test]
    fn fuse_errors_when_media_runs_out() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(100),
            ..Default::default()
        };
        let mut text = vec![0.0; 8];
        let ids = [100, 100];
        // A single image row for two image placeholders.
        let img = vec![1.0; 4];

        let err =
            fuse_multimodal_embeddings(&mut text, &ids, 4, &cfg, &img, &[], &[]).unwrap_err();
        assert!(err.to_string().contains("image_embeds exhausted"));
    }

    #[test]
    fn empty_config_is_no_op() {
        let cfg = GemmaMultimodalConfig::default();
        let mut text = vec![1.0, 2.0, 3.0, 4.0];
        let ids = [10, 20];

        fuse_multimodal_embeddings(&mut text, &ids, 2, &cfg, &[], &[], &[]).unwrap();
        assert_eq!(text, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn extract_image_patches_shapes_match_expected_grid() {
        let rgb: Vec<u8> = (0..(4 * 4 * 3) as u8).collect();
        let out = extract_image_patches(&rgb, 4, 4, 2).unwrap();

        assert_eq!(out.len(), 4 * 12);
        assert!((out[0] - 0.0 / 255.0).abs() < 1e-6);
        assert!((out[1] - 1.0 / 255.0).abs() < 1e-6);
        assert!((out[2] - 2.0 / 255.0).abs() < 1e-6);
        assert!((out[3] - 3.0 / 255.0).abs() < 1e-6);
    }

    #[test]
    fn extract_image_patches_truncates_partial_pixels() {
        // 5x5 with 2x2 patches: the fifth row and column are dropped.
        let rgb = vec![0u8; 5 * 5 * 3];
        let out = extract_image_patches(&rgb, 5, 5, 2).unwrap();
        assert_eq!(out.len(), 4 * 12);
    }

    #[test]
    fn extract_image_patches_rejects_size_mismatch() {
        let rgb = vec![0u8; 4 * 4 * 3 - 1];
        assert!(extract_image_patches(&rgb, 4, 4, 2).is_err());
    }

    #[test]
    fn frame_audio_samples_pads_last_frame() {
        let samples = vec![1.0f32; 1500];
        let (out, n) = frame_audio_samples(&samples, 640).unwrap();

        assert_eq!(n, 3);
        assert_eq!(out.len(), 3 * 640);
        for &v in &out[1500..] {
            assert_eq!(v, 0.0);
        }
        for &v in &out[..1500] {
            assert_eq!(v, 1.0);
        }
    }

    #[test]
    fn frame_audio_samples_minimum_one_frame() {
        let (out, n) = frame_audio_samples(&[], 640).unwrap();
        assert_eq!(n, 1);
        assert_eq!(out.len(), 640);
    }

    #[test]
    fn expand_media_placeholders_brackets_and_inlines_tokens() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(900),
            boi_token_id: Some(800),
            eoi_token_id: Some(801),
            audio_token_id: Some(950),
            boa_token_id: Some(850),
            eoa_token_index: Some(851),
            ..Default::default()
        };
        let chunks = vec![vec![1, 2], vec![3], vec![4, 5]];
        let slots = vec![MediaSlot::Image { count: 4 }, MediaSlot::Audio { count: 2 }];

        let out = expand_media_placeholders(&chunks, &slots, &cfg).unwrap();
        assert_eq!(
            out,
            vec![1, 2, 800, 900, 900, 900, 900, 801, 3, 850, 950, 950, 851, 4, 5],
        );
    }

    #[test]
    fn expand_media_placeholders_rejects_mismatched_chunks() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(900),
            ..Default::default()
        };
        let chunks = vec![vec![1]];
        let slots = vec![MediaSlot::Image { count: 4 }];
        assert!(expand_media_placeholders(&chunks, &slots, &cfg).is_err());
    }

    #[test]
    fn standalone_projector_graphs_only_take_media_as_input() {
        let v_cfg = GemmaVisionConfig::default();
        let vision = build_vision_projection_graph(1, 16, &v_cfg).unwrap();
        assert_eq!(vision.input_keys, vec!["patches".to_string()]);

        let a_cfg = GemmaAudioConfig::default();
        let audio = build_audio_projection_graph(1, 8, &a_cfg, 3840).unwrap();
        assert_eq!(audio.input_keys, vec!["frames".to_string()]);
    }

    #[test]
    fn parse_wav_decodes_minimal_pcm16_mono() {
        let samples_i16: [i16; 4] = [0, 16_384, -16_384, 32_767];
        let data_len = samples_i16.len() * 2;
        // "WAVE" tag + fmt chunk (header + 16) + data chunk (header + payload).
        let total_size = 4 + (8 + 16) + (8 + data_len);

        let mut bytes = Vec::new();
        bytes.extend_from_slice(b"RIFF");
        bytes.extend_from_slice(&(total_size as u32).to_le_bytes());
        bytes.extend_from_slice(b"WAVE");

        bytes.extend_from_slice(b"fmt ");
        bytes.extend_from_slice(&16u32.to_le_bytes()); // fmt chunk size
        bytes.extend_from_slice(&1u16.to_le_bytes()); // format tag: PCM
        bytes.extend_from_slice(&1u16.to_le_bytes()); // channels
        bytes.extend_from_slice(&16_000u32.to_le_bytes()); // sample rate
        bytes.extend_from_slice(&32_000u32.to_le_bytes()); // byte rate
        bytes.extend_from_slice(&2u16.to_le_bytes()); // block align
        bytes.extend_from_slice(&16u16.to_le_bytes()); // bits per sample

        bytes.extend_from_slice(b"data");
        bytes.extend_from_slice(&(data_len as u32).to_le_bytes());
        for s in samples_i16 {
            bytes.extend_from_slice(&s.to_le_bytes());
        }

        let pcm = parse_wav_16khz_mono(&bytes).unwrap();
        assert_eq!(pcm.len(), 4);
        assert!((pcm[0] - 0.0).abs() < 1e-4);
        assert!((pcm[1] - 0.5).abs() < 1e-3);
        assert!((pcm[2] - (-0.5)).abs() < 1e-3);
        assert!((pcm[3] - 1.0).abs() < 1e-3);
    }

    #[test]
    fn resample_linear_preserves_constants() {
        let src = vec![0.7f32; 1000];
        let out = resample_linear(&src, 48_000, 16_000);

        assert!((out.len() as i32 - 333).abs() <= 1);
        for &v in &out {
            assert!((v - 0.7).abs() < 1e-5);
        }
    }

    #[test]
    fn tokenize_with_media_splits_and_expands() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(900),
            boi_token_id: Some(800),
            eoi_token_id: Some(801),
            audio_token_id: Some(950),
            boa_token_id: Some(850),
            eoa_token_index: Some(851),
            ..Default::default()
        };
        let encode = |s: &str| -> Result<Vec<u32>> { Ok(byte_tokens(s.as_bytes())) };
        let prompt = "hi <image> see <audio> bye";
        let slots = vec![MediaSlot::Image { count: 2 }, MediaSlot::Audio { count: 1 }];

        let out = tokenize_with_media(prompt, &slots, &cfg, encode).unwrap();

        let mut expected = byte_tokens(b"hi ");
        expected.extend([800, 900, 900, 801]);
        expected.extend(byte_tokens(b" see "));
        expected.extend([850, 950, 851]);
        expected.extend(byte_tokens(b" bye"));
        assert_eq!(out, expected);
    }

    #[test]
    fn tokenize_with_media_rejects_slot_marker_mismatch() {
        let cfg = GemmaMultimodalConfig {
            image_token_id: Some(900),
            ..Default::default()
        };
        let encode = |_: &str| -> Result<Vec<u32>> { Ok(vec![]) };

        // Two markers, one slot.
        let err = tokenize_with_media(
            "a <image> b <image> c",
            &[MediaSlot::Image { count: 1 }],
            &cfg,
            encode,
        )
        .unwrap_err();
        assert!(err.to_string().contains("media markers"));
    }
}