use anyhow::{Result, anyhow, bail};
use rlx_ir::hir::{HirModule, HirMut, HirNodeId};
use rlx_ir::{DType, HirGraphExt, Shape};
use serde::Deserialize;
use std::path::Path;
use std::str::FromStr;
/// Vision-side settings, parsed from the `vision_config` object of a Gemma
/// multimodal config.
///
/// Because of `#[serde(default)]`, any key missing from the JSON falls back to
/// the value supplied by this type's [`Default`] impl.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct GemmaVisionConfig {
/// Side length of one square patch; the projector graph builders use
/// `patch_size * patch_size * 3` features per patch.
pub patch_size: usize,
/// NOTE(review): not read anywhere in this file — presumably the patch size
/// used by the underlying vision model; confirm against the caller.
pub model_patch_size: usize,
/// Width of the patch embedding; used as the output dimension of the patch
/// embedding weight and the size of the norm parameters.
pub mm_embed_dim: usize,
/// NOTE(review): not read anywhere in this file — presumably the size of the
/// positional-embedding table; confirm.
pub mm_posemb_size: usize,
/// Number of soft tokens produced per image by the projector graphs.
pub num_soft_tokens: usize,
/// Output width of the final projection weight in the vision projector graphs.
pub output_proj_dims: usize,
/// NOTE(review): not read anywhere in this file — presumably a pooling window
/// size; confirm.
pub pooling_kernel_size: usize,
/// Epsilon handed to the RMS-norm op (narrowed to `f32` at the call site).
pub rms_norm_eps: f64,
}
/// Fallback values used for any key missing from `vision_config`.
///
/// These match the `vision_config` values of the Gemma fixture embedded in
/// this file's test module.
impl Default for GemmaVisionConfig {
fn default() -> Self {
Self {
patch_size: 16,
model_patch_size: 48,
mm_embed_dim: 3840,
mm_posemb_size: 1120,
num_soft_tokens: 280,
output_proj_dims: 3840,
pooling_kernel_size: 3,
rms_norm_eps: 1e-6,
}
}
}
/// Audio-side settings, parsed from the `audio_config` object of a Gemma
/// multimodal config.
///
/// Because of `#[serde(default)]`, any key missing from the JSON falls back to
/// the value supplied by this type's [`Default`] impl.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct GemmaAudioConfig {
/// NOTE(review): not read anywhere in this file — presumably the audio
/// encoder's hidden width; confirm.
pub hidden_size: usize,
/// Width of the audio frame embedding; used as the output dimension of the
/// audio embedding weight and the size of the norm parameters.
pub audio_embed_dim: usize,
/// Number of samples in one audio frame fed to the audio projector graph.
pub audio_samples_per_token: usize,
/// NOTE(review): not read by the audio graph builder in this file, which
/// takes the output width as a separate `lm_hidden` argument; confirm intent.
pub output_proj_dims: usize,
/// Epsilon handed to the RMS-norm op (narrowed to `f32` at the call site).
pub rms_norm_eps: f64,
}
/// Fallback values used for any key missing from `audio_config`.
///
/// These match the `audio_config` values of the Gemma fixture embedded in this
/// file's test module.
impl Default for GemmaAudioConfig {
fn default() -> Self {
Self {
hidden_size: 640,
audio_embed_dim: 640,
audio_samples_per_token: 640,
output_proj_dims: 640,
rms_norm_eps: 1e-6,
}
}
}
/// Top-level multimodal settings: optional vision/audio sub-configs plus the
/// special token ids used to splice media into a token stream.
///
/// Every field is optional; a default-constructed value has no modalities and
/// no token ids.
///
/// The hand-written [`FromStr`] impl reads the sub-configs from the
/// `vision_config` / `audio_config` keys. The serde aliases below make the
/// derived `Deserialize` accept those same keys (in addition to the plain
/// field names), so both parsing paths agree on the on-disk layout.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct GemmaMultimodalConfig {
    /// Vision settings, if the model has a vision tower.
    #[serde(default, alias = "vision_config")]
    pub vision: Option<GemmaVisionConfig>,
    /// Audio settings, if the model has an audio tower.
    #[serde(default, alias = "audio_config")]
    pub audio: Option<GemmaAudioConfig>,
    /// Placeholder token repeated for each image soft-token position.
    #[serde(default)]
    pub image_token_id: Option<u32>,
    /// Placeholder token repeated for each audio soft-token position.
    #[serde(default)]
    pub audio_token_id: Option<u32>,
    /// Placeholder token repeated for each video soft-token position.
    #[serde(default)]
    pub video_token_id: Option<u32>,
    /// Token emitted before a run of image (or video) placeholders.
    #[serde(default)]
    pub boi_token_id: Option<u32>,
    /// Token emitted after a run of image (or video) placeholders.
    #[serde(default)]
    pub eoi_token_id: Option<u32>,
    /// Token emitted before a run of audio placeholders.
    #[serde(default)]
    pub boa_token_id: Option<u32>,
    /// Token emitted after a run of audio placeholders. The `_index` suffix
    /// mirrors the JSON key this field is read from.
    #[serde(default)]
    pub eoa_token_index: Option<u32>,
}
impl GemmaMultimodalConfig {
pub fn from_file(path: &Path) -> Result<Self> {
let data = std::fs::read_to_string(path)?;
Self::parse_json(&data)
}
pub fn parse_json(raw: &str) -> Result<Self> {
raw.parse()
}
pub fn has_vision(&self) -> bool {
self.vision.is_some()
}
pub fn has_audio(&self) -> bool {
self.audio.is_some()
}
}
impl FromStr for GemmaMultimodalConfig {
type Err = anyhow::Error;
fn from_str(raw: &str) -> Result<Self, Self::Err> {
let value: serde_json::Value = serde_json::from_str(raw)?;
let vision = value
.get("vision_config")
.filter(|v| v.is_object())
.map(|v| serde_json::from_value::<GemmaVisionConfig>(v.clone()))
.transpose()?;
let audio = value
.get("audio_config")
.filter(|v| v.is_object())
.map(|v| serde_json::from_value::<GemmaAudioConfig>(v.clone()))
.transpose()?;
let pick_u32 = |k: &str| value.get(k).and_then(|v| v.as_u64()).map(|x| x as u32);
Ok(Self {
vision,
audio,
image_token_id: pick_u32("image_token_id"),
audio_token_id: pick_u32("audio_token_id"),
video_token_id: pick_u32("video_token_id"),
boi_token_id: pick_u32("boi_token_id"),
eoi_token_id: pick_u32("eoi_token_id"),
boa_token_id: pick_u32("boa_token_id"),
eoa_token_index: pick_u32("eoa_token_index"),
})
}
}
/// Appends the vision projector ops to `hir` and returns the output node.
///
/// The emitted chain is: `patches · embed_w`, plus `pos_embed`, RMS-normed with
/// scale `ones + norm_w` and bias `zero_beta`; then transposed with permutation
/// `[0, 2, 1]`, multiplied by `soft_token_w`, transposed back, and finally
/// multiplied by `lm_proj_w`.
///
/// Only `cfg.rms_norm_eps` is read from `cfg` (narrowed to `f32`).
///
/// # Errors
/// The current implementation never returns `Err`; the `Result` return type is
/// kept for interface stability.
pub fn build_vision_projection_hir(
    hir: &mut HirModule,
    inputs: VisionProjectionInputs,
    cfg: &GemmaVisionConfig,
) -> Result<HirNodeId> {
    // Stage 1: embed patches, add positions, normalize.
    let normed = {
        let mut gb = HirMut::new(hir);
        let projected = gb.mm(inputs.patches, inputs.embed_w);
        let with_pos = gb.add(projected, inputs.pos_embed);
        // The effective norm scale is `1 + norm_w`.
        let gamma = gb.add(inputs.ones, inputs.norm_w);
        gb.rms_norm(with_pos, gamma, inputs.zero_beta, cfg.rms_norm_eps as f32)
    };

    // Stage 2: mix along the patch axis into soft tokens, then project.
    let mut gb = HirMut::new(hir);
    let normed_t = gb.transpose_(normed, vec![0, 2, 1]);
    let soft = gb.mm(normed_t, inputs.soft_token_w);
    let soft_t = gb.transpose_(soft, vec![0, 2, 1]);
    let out = gb.mm(soft_t, inputs.lm_proj_w);
    Ok(out)
}
/// Node handles consumed by [`build_vision_projection_hir`].
#[derive(Debug, Clone, Copy)]
pub struct VisionProjectionInputs {
/// Flattened image patches (left operand of the first matmul).
pub patches: HirNodeId,
/// Patch embedding weight (right operand of the first matmul).
pub embed_w: HirNodeId,
/// Positional embedding added to the embedded patches.
pub pos_embed: HirNodeId,
/// Learned norm weight; added to `ones` to form the RMS-norm scale.
pub norm_w: HirNodeId,
/// Tensor added to `norm_w`; presumably all ones — the builder does not check.
pub ones: HirNodeId,
/// Bias passed to the RMS-norm op; presumably all zeros — the builder does not check.
pub zero_beta: HirNodeId,
/// Weight multiplied against the transposed, normalized embeddings.
pub soft_token_w: HirNodeId,
/// Final projection weight.
pub lm_proj_w: HirNodeId,
}
/// Node handles consumed by [`build_vision_projection_learned_queries_hir`].
#[derive(Debug, Clone, Copy)]
pub struct VisionProjectionLearnedQueriesInputs {
/// Flattened image patches (left operand of the first matmul).
pub patches: HirNodeId,
/// Patch embedding weight (right operand of the first matmul).
pub embed_w: HirNodeId,
/// Positional embedding added to the embedded patches.
pub pos_embed: HirNodeId,
/// Learned norm weight; added to `ones` to form the RMS-norm scale.
pub norm_w: HirNodeId,
/// Tensor added to `norm_w`; presumably all ones — the builder does not check.
pub ones: HirNodeId,
/// Bias passed to the RMS-norm op; presumably all zeros — the builder does not check.
pub zero_beta: HirNodeId,
/// Query tensor fed to the attention op; its shape is reused as the attention output shape.
pub queries: HirNodeId,
/// Weight producing attention keys from the normalized embeddings.
pub k_proj: HirNodeId,
/// Weight producing attention values from the normalized embeddings.
pub v_proj: HirNodeId,
/// Final projection weight applied to the attention output.
pub lm_proj_w: HirNodeId,
}
/// Appends the learned-queries vision projector ops to `hir` and returns the
/// output node.
///
/// Patches are embedded, offset by `pos_embed`, and RMS-normed (scale
/// `ones + norm_w`, bias `zero_beta`). Keys and values are projected from the
/// normalized embeddings, a single-head unmasked attention op is emitted with
/// `inputs.queries` as the query, and the result is multiplied by `lm_proj_w`.
///
/// The attention output is declared with the same shape as `inputs.queries`,
/// and `head_dim` is set to `cfg.mm_embed_dim`.
///
/// # Errors
/// The current implementation never returns `Err`; the `Result` return type is
/// kept for interface stability.
pub fn build_vision_projection_learned_queries_hir(
    hir: &mut HirModule,
    inputs: VisionProjectionLearnedQueriesInputs,
    cfg: &GemmaVisionConfig,
) -> Result<HirNodeId> {
    // Stage 1: embed patches, add positions, normalize.
    let normed = {
        let mut gb = HirMut::new(hir);
        let projected = gb.mm(inputs.patches, inputs.embed_w);
        let with_pos = gb.add(projected, inputs.pos_embed);
        // The effective norm scale is `1 + norm_w`.
        let gamma = gb.add(inputs.ones, inputs.norm_w);
        gb.rms_norm(with_pos, gamma, inputs.zero_beta, cfg.rms_norm_eps as f32)
    };

    // Stage 2: cross-attend the learned queries over the patch embeddings.
    let mut gb = HirMut::new(hir);
    let k = gb.mm(normed, inputs.k_proj);
    let v = gb.mm(normed, inputs.v_proj);
    // Attention preserves the query shape.
    let attn_shape = gb.shape(inputs.queries).clone();
    let attn = gb.0.mir(
        rlx_ir::Op::Attention {
            num_heads: 1,
            head_dim: cfg.mm_embed_dim,
            mask_kind: rlx_ir::op::MaskKind::None,
            score_scale: None,
            attn_logit_softcap: None,
        },
        vec![inputs.queries, k, v],
        attn_shape,
    );
    let out = gb.mm(attn, inputs.lm_proj_w);
    Ok(out)
}
/// Builds a standalone HIR module named `gemma_vision_projector_lq` holding the
/// learned-queries vision projector.
///
/// `patches` is the only graph input; every weight is declared as a named
/// parameter under the `vision_tower.` prefix. Shapes are derived from
/// `batch`, `num_patches`, and `cfg`.
///
/// # Errors
/// Propagates any error from [`build_vision_projection_learned_queries_hir`].
pub fn build_vision_projection_learned_queries_graph(
    batch: usize,
    num_patches: usize,
    cfg: &GemmaVisionConfig,
) -> Result<ProjectionGraph> {
    let dim = cfg.mm_embed_dim;
    let patch_features = cfg.patch_size * cfg.patch_size * 3;
    // Shape shared by the three per-channel norm parameters.
    let per_channel = || Shape::new(&[dim], DType::F32);
    // Shape shared by the key and value projection weights.
    let square = || Shape::new(&[dim, dim], DType::F32);

    let mut hir = HirModule::new("gemma_vision_projector_lq");
    // Field initializers run top to bottom, so nodes are declared in the
    // order listed here.
    let inputs = VisionProjectionLearnedQueriesInputs {
        patches: hir.input(
            "patches",
            Shape::new(&[batch, num_patches, patch_features], DType::F32),
        ),
        embed_w: hir.param(
            "vision_tower.embed.weight",
            Shape::new(&[patch_features, dim], DType::F32),
        ),
        pos_embed: hir.param(
            "vision_tower.pos_embed.weight",
            Shape::new(&[num_patches, dim], DType::F32),
        ),
        norm_w: hir.param("vision_tower.norm.weight", per_channel()),
        ones: hir.param("vision_tower.ones", per_channel()),
        zero_beta: hir.param("vision_tower.zero_beta", per_channel()),
        queries: hir.param(
            "vision_tower.queries.weight",
            Shape::new(&[batch, cfg.num_soft_tokens, dim], DType::F32),
        ),
        k_proj: hir.param("vision_tower.k_proj.weight", square()),
        v_proj: hir.param("vision_tower.v_proj.weight", square()),
        lm_proj_w: hir.param(
            "vision_tower.lm_proj.weight",
            Shape::new(&[dim, cfg.output_proj_dims], DType::F32),
        ),
    };

    let output = build_vision_projection_learned_queries_hir(&mut hir, inputs, cfg)?;
    hir.set_outputs(vec![output]);
    Ok(ProjectionGraph {
        hir,
        output,
        input_keys: vec!["patches".into()],
    })
}
/// Appends the audio projector ops to `hir` and returns the output node.
///
/// The chain is `frames · embed_w`, RMS-normed with scale `ones + norm_w` and
/// bias `zero_beta`, then multiplied by `lm_proj_w`. Only `cfg.rms_norm_eps`
/// is read from `cfg` (narrowed to `f32`).
///
/// # Errors
/// The current implementation never returns `Err`.
pub fn build_audio_projection_hir(
    hir: &mut HirModule,
    inputs: AudioProjectionInputs,
    cfg: &GemmaAudioConfig,
) -> Result<HirNodeId> {
    let eps = cfg.rms_norm_eps as f32;
    let mut builder = HirMut::new(hir);
    let embedded = builder.mm(inputs.frames, inputs.embed_w);
    // The effective norm scale is `1 + norm_w`.
    let scale = builder.add(inputs.ones, inputs.norm_w);
    let normalized = builder.rms_norm(embedded, scale, inputs.zero_beta, eps);
    Ok(builder.mm(normalized, inputs.lm_proj_w))
}
/// Node handles consumed by [`build_audio_projection_hir`].
#[derive(Debug, Clone, Copy)]
pub struct AudioProjectionInputs {
/// Framed audio samples (left operand of the first matmul).
pub frames: HirNodeId,
/// Frame embedding weight (right operand of the first matmul).
pub embed_w: HirNodeId,
/// Learned norm weight; added to `ones` to form the RMS-norm scale.
pub norm_w: HirNodeId,
/// Tensor added to `norm_w`; presumably all ones — the builder does not check.
pub ones: HirNodeId,
/// Bias passed to the RMS-norm op; presumably all zeros — the builder does not check.
pub zero_beta: HirNodeId,
/// Final projection weight.
pub lm_proj_w: HirNodeId,
}
/// A standalone projector module as returned by the `build_*_graph` functions.
#[derive(Debug)]
pub struct ProjectionGraph {
/// The module holding the projector ops, with `output` registered as its output.
pub hir: HirModule,
/// Node producing the projected embeddings.
pub output: HirNodeId,
/// Names of the graph inputs declared via `hir.input` (weights are params,
/// not inputs).
pub input_keys: Vec<String>,
}
/// Builds a standalone HIR module named `gemma_vision_projector` holding the
/// soft-token vision projector.
///
/// `patches` is the only graph input; every weight is declared as a named
/// parameter under the `vision_tower.` prefix. Shapes are derived from
/// `batch`, `num_patches`, and `cfg`.
///
/// # Errors
/// Propagates any error from [`build_vision_projection_hir`].
pub fn build_vision_projection_graph(
    batch: usize,
    num_patches: usize,
    cfg: &GemmaVisionConfig,
) -> Result<ProjectionGraph> {
    let dim = cfg.mm_embed_dim;
    let patch_features = cfg.patch_size * cfg.patch_size * 3;
    // Shape shared by the three per-channel norm parameters.
    let per_channel = || Shape::new(&[dim], DType::F32);

    let mut hir = HirModule::new("gemma_vision_projector");
    // Field initializers run top to bottom, so nodes are declared in the
    // order listed here.
    let inputs = VisionProjectionInputs {
        patches: hir.input(
            "patches",
            Shape::new(&[batch, num_patches, patch_features], DType::F32),
        ),
        embed_w: hir.param(
            "vision_tower.embed.weight",
            Shape::new(&[patch_features, dim], DType::F32),
        ),
        pos_embed: hir.param(
            "vision_tower.pos_embed.weight",
            Shape::new(&[num_patches, dim], DType::F32),
        ),
        norm_w: hir.param("vision_tower.norm.weight", per_channel()),
        ones: hir.param("vision_tower.ones", per_channel()),
        zero_beta: hir.param("vision_tower.zero_beta", per_channel()),
        soft_token_w: hir.param(
            "vision_tower.soft_token.weight",
            Shape::new(&[num_patches, cfg.num_soft_tokens], DType::F32),
        ),
        lm_proj_w: hir.param(
            "vision_tower.lm_proj.weight",
            Shape::new(&[dim, cfg.output_proj_dims], DType::F32),
        ),
    };

    let output = build_vision_projection_hir(&mut hir, inputs, cfg)?;
    hir.set_outputs(vec![output]);
    Ok(ProjectionGraph {
        hir,
        output,
        input_keys: vec!["patches".into()],
    })
}
/// Builds a standalone HIR module named `gemma_audio_projector` holding the
/// audio projector.
///
/// `frames` is the only graph input; every weight is declared as a named
/// parameter under the `audio_tower.` prefix. The final projection maps
/// `cfg.audio_embed_dim` to `lm_hidden`.
///
/// # Errors
/// Propagates any error from [`build_audio_projection_hir`].
pub fn build_audio_projection_graph(
    batch: usize,
    num_frames: usize,
    cfg: &GemmaAudioConfig,
    lm_hidden: usize,
) -> Result<ProjectionGraph> {
    let frame_len = cfg.audio_samples_per_token;
    let dim = cfg.audio_embed_dim;
    // Shape shared by the three per-channel norm parameters.
    let per_channel = || Shape::new(&[dim], DType::F32);

    let mut hir = HirModule::new("gemma_audio_projector");
    // Field initializers run top to bottom, so nodes are declared in the
    // order listed here.
    let inputs = AudioProjectionInputs {
        frames: hir.input(
            "frames",
            Shape::new(&[batch, num_frames, frame_len], DType::F32),
        ),
        embed_w: hir.param(
            "audio_tower.embed.weight",
            Shape::new(&[frame_len, dim], DType::F32),
        ),
        norm_w: hir.param("audio_tower.norm.weight", per_channel()),
        ones: hir.param("audio_tower.ones", per_channel()),
        zero_beta: hir.param("audio_tower.zero_beta", per_channel()),
        lm_proj_w: hir.param(
            "audio_tower.lm_proj.weight",
            Shape::new(&[dim, lm_hidden], DType::F32),
        ),
    };

    let output = build_audio_projection_hir(&mut hir, inputs, cfg)?;
    hir.set_outputs(vec![output]);
    Ok(ProjectionGraph {
        hir,
        output,
        input_keys: vec!["frames".into()],
    })
}
/// Per-channel normalization applied to RGB pixels after scaling to `[0, 1]`:
/// each channel becomes `(value - mean) / std`.
#[derive(Debug, Clone, Copy)]
pub struct ImageNormalize {
    /// Per-channel mean, in R, G, B order.
    pub mean: [f32; 3],
    /// Per-channel standard deviation, in R, G, B order.
    pub std: [f32; 3],
}

impl ImageNormalize {
    /// Identity normalization: zero mean, unit standard deviation.
    pub const fn unit() -> Self {
        Self {
            mean: [0.0, 0.0, 0.0],
            std: [1.0, 1.0, 1.0],
        }
    }

    /// ImageNet mean / standard-deviation preset.
    pub const fn imagenet() -> Self {
        let mean = [0.485, 0.456, 0.406];
        let std = [0.229, 0.224, 0.225];
        Self { mean, std }
    }

    /// CLIP mean / standard-deviation preset.
    pub const fn clip() -> Self {
        let mean = [0.48145466, 0.4578275, 0.40821073];
        let std = [0.26862954, 0.2613026, 0.2757771];
        Self { mean, std }
    }
}

/// The default preset is [`ImageNormalize::imagenet`].
impl Default for ImageNormalize {
    fn default() -> Self {
        ImageNormalize::imagenet()
    }
}
/// Splits an RGB8 image into patches scaled to `[0, 1]` with no further
/// normalization.
///
/// Shorthand for [`extract_image_patches_normalized`] with
/// [`ImageNormalize::unit`].
pub fn extract_image_patches(
    rgb: &[u8],
    width: usize,
    height: usize,
    patch_size: usize,
) -> Result<Vec<f32>> {
    let identity = ImageNormalize::unit();
    extract_image_patches_normalized(rgb, width, height, patch_size, identity)
}
pub fn extract_image_patches_normalized(
rgb: &[u8],
width: usize,
height: usize,
patch_size: usize,
norm: ImageNormalize,
) -> Result<Vec<f32>> {
if rgb.len() != width * height * 3 {
bail!(
"image buffer is {} bytes but {}x{}x3 = {}",
rgb.len(),
width,
height,
width * height * 3,
);
}
if patch_size == 0 {
bail!("patch_size must be > 0");
}
let patch_cols = width / patch_size;
let patch_rows = height / patch_size;
let num_patches = patch_rows * patch_cols;
let per_patch = patch_size * patch_size * 3;
let mut out = vec![0f32; num_patches * per_patch];
let row_stride_bytes = width * 3;
let inv = 1.0_f32 / 255.0;
let scale = [inv / norm.std[0], inv / norm.std[1], inv / norm.std[2]];
let bias = [
-norm.mean[0] / norm.std[0],
-norm.mean[1] / norm.std[1],
-norm.mean[2] / norm.std[2],
];
for pr in 0..patch_rows {
let pr_base_y = pr * patch_size;
for pc in 0..patch_cols {
let patch_index = pr * patch_cols + pc;
let dst_base = patch_index * per_patch;
let pc_base_x = pc * patch_size;
for py in 0..patch_size {
let src_row_off = (pr_base_y + py) * row_stride_bytes + pc_base_x * 3;
let dst_row_off = dst_base + py * patch_size * 3;
let src = &rgb[src_row_off..src_row_off + patch_size * 3];
let dst = &mut out[dst_row_off..dst_row_off + patch_size * 3];
for px in 0..patch_size {
let s = px * 3;
dst[s] = src[s] as f32 * scale[0] + bias[0];
dst[s + 1] = src[s + 1] as f32 * scale[1] + bias[1];
dst[s + 2] = src[s + 2] as f32 * scale[2] + bias[2];
}
}
}
}
Ok(out)
}
/// Packs `samples` into whole frames of `samples_per_token` values, padding
/// the tail of the last frame with zeros.
///
/// Returns the padded buffer together with the frame count. At least one
/// frame is always produced, so empty input yields a single all-zero frame.
///
/// # Errors
/// Fails when `samples_per_token` is zero.
pub fn frame_audio_samples(samples: &[f32], samples_per_token: usize) -> Result<(Vec<f32>, usize)> {
    if samples_per_token == 0 {
        bail!("samples_per_token must be > 0");
    }
    let num_frames = samples.len().div_ceil(samples_per_token).max(1);
    let padded_len = num_frames * samples_per_token;
    // `padded_len >= samples.len()`, so every sample is kept and the rest of
    // the buffer is zero padding.
    let mut framed = Vec::with_capacity(padded_len);
    framed.extend_from_slice(samples);
    framed.resize(padded_len, 0.0);
    Ok((framed, num_frames))
}
/// Loads an image and splits it into patches using the ImageNet
/// normalization preset.
///
/// Shorthand for [`load_image_patches_normalized`] with
/// [`ImageNormalize::imagenet`]; see that function for the return value.
pub fn load_image_patches(
    path: impl AsRef<std::path::Path>,
    patch_size: usize,
    max_side_patches: usize,
) -> Result<(Vec<f32>, usize, usize)> {
    let norm = ImageNormalize::imagenet();
    load_image_patches_normalized(path, patch_size, max_side_patches, norm)
}
pub fn load_image_patches_normalized(
path: impl AsRef<std::path::Path>,
patch_size: usize,
max_side_patches: usize,
norm: ImageNormalize,
) -> Result<(Vec<f32>, usize, usize)> {
let img = image::open(path.as_ref()).map_err(|e| anyhow!("decode {:?}: {e}", path.as_ref()))?;
let rgb = img.to_rgb8();
let (w, h) = rgb.dimensions();
let (w, h) = (w as usize, h as usize);
let cap_px = max_side_patches.max(1) * patch_size;
let target_w = (w.min(cap_px) / patch_size).max(1) * patch_size;
let target_h = (h.min(cap_px) / patch_size).max(1) * patch_size;
let resized = if (target_w, target_h) != (w, h) {
image::DynamicImage::ImageRgb8(rgb)
.resize_exact(
target_w as u32,
target_h as u32,
image::imageops::FilterType::Triangle,
)
.to_rgb8()
} else {
rgb
};
let patches = extract_image_patches_normalized(
resized.as_raw(),
resized.width() as usize,
resized.height() as usize,
patch_size,
norm,
)?;
Ok((patches, target_h / patch_size, target_w / patch_size))
}
/// Sample rate, in Hz, that `parse_wav_16khz_mono` resamples decoded audio to.
const SAMPLE_RATE_GEMMA4_HZ: u32 = 16_000;
/// Reads a WAV file from disk and decodes it with [`parse_wav_16khz_mono`].
///
/// # Errors
/// Fails when the file cannot be read (the error names the path) or when the
/// bytes cannot be decoded.
pub fn load_wav_mono_16khz(path: impl AsRef<std::path::Path>) -> Result<Vec<f32>> {
    let path = path.as_ref();
    let bytes = match std::fs::read(path) {
        Ok(bytes) => bytes,
        Err(e) => return Err(anyhow!("read {:?}: {e}", path)),
    };
    parse_wav_16khz_mono(&bytes)
}
pub fn parse_wav_16khz_mono(bytes: &[u8]) -> Result<Vec<f32>> {
let (channels, src_rate, samples) = parse_pcm16_wav(bytes)?;
let mono = if channels == 1 {
samples
} else {
let n = samples.len() / channels as usize;
let mut out = Vec::with_capacity(n);
for frame in 0..n {
let base = frame * channels as usize;
let mut sum = 0.0f32;
for c in 0..channels as usize {
sum += samples[base + c];
}
out.push(sum / channels as f32);
}
out
};
if src_rate == SAMPLE_RATE_GEMMA4_HZ {
Ok(mono)
} else {
Ok(resample_linear(&mono, src_rate, SAMPLE_RATE_GEMMA4_HZ))
}
}
/// Resamples `samples` from `src_rate` to `dst_rate` by linear interpolation.
///
/// The output holds `round(len * dst_rate / src_rate)` samples. Output sample
/// `i` is read at source position `i * src_rate / dst_rate`, interpolating
/// between the two neighbouring source samples; positions past the last
/// sample reuse the last sample.
///
/// Returns a copy of the input when the rates are equal or the input is
/// empty, and an empty vector when either rate is zero (there is no
/// meaningful conversion in that case).
pub fn resample_linear(samples: &[f32], src_rate: u32, dst_rate: u32) -> Vec<f32> {
    if src_rate == dst_rate || samples.is_empty() {
        return samples.to_vec();
    }
    // A zero source rate would make the ratio infinite and the output length
    // saturate to `usize::MAX`, aborting on allocation.
    if src_rate == 0 {
        return Vec::new();
    }

    let ratio = f64::from(dst_rate) / f64::from(src_rate);
    let out_len = ((samples.len() as f64) * ratio).round() as usize;
    let step = f64::from(src_rate) / f64::from(dst_rate);
    let last = samples.len() - 1;

    (0..out_len)
        .map(|i| {
            let pos = i as f64 * step;
            // Clamp defensively so floating-point rounding can never index
            // past the end of the input.
            let lo = (pos.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = (pos - lo as f64) as f32;
            let a = samples[lo];
            let b = samples[hi];
            a + (b - a) * frac
        })
        .collect()
}
/// Parses a RIFF/WAVE buffer holding 16-bit PCM audio.
///
/// Returns `(channels, sample_rate, samples)`, where `samples` are the
/// interleaved channel values scaled by `1 / 32768`.
///
/// Chunks after the 12-byte RIFF header are walked in order; unknown chunks
/// are skipped. If a chunk type appears more than once, the last occurrence
/// wins. A chunk whose declared size runs past the end of the buffer is
/// truncated to the bytes available.
///
/// # Errors
/// Fails when the buffer is shorter than 44 bytes or lacks the `RIFF`/`WAVE`
/// tags, when the `fmt ` chunk is missing or shorter than 16 bytes, when the
/// format is not PCM (format tag 1) at 16 bits per sample, when the `data`
/// chunk is missing, or when its length is odd.
fn parse_pcm16_wav(bytes: &[u8]) -> Result<(u16, u32, Vec<f32>)> {
    if bytes.len() < 44 || &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
        bail!("not a RIFF/WAVE file");
    }

    let mut fmt: Option<(u16, u16, u32, u16)> = None;
    let mut data_chunk: Option<&[u8]> = None;

    // Walk the chunk list with a shrinking slice rather than a byte offset, so
    // a huge declared chunk size cannot overflow the cursor on 32-bit targets.
    let mut rest = &bytes[12..];
    while rest.len() >= 8 {
        let (header, body) = rest.split_at(8);
        let chunk_id = &header[..4];
        let chunk_size =
            u32::from_le_bytes([header[4], header[5], header[6], header[7]]) as usize;
        let chunk = &body[..chunk_size.min(body.len())];
        match chunk_id {
            b"fmt " => {
                if chunk.len() < 16 {
                    bail!("wav fmt chunk too small");
                }
                let audio_format = u16::from_le_bytes([chunk[0], chunk[1]]);
                let channels = u16::from_le_bytes([chunk[2], chunk[3]]);
                let sr = u32::from_le_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
                let bps = u16::from_le_bytes([chunk[14], chunk[15]]);
                fmt = Some((audio_format, channels, sr, bps));
            }
            b"data" => data_chunk = Some(chunk),
            _ => {}
        }
        // RIFF chunks are padded to an even length.
        let advance = chunk_size.saturating_add(chunk_size % 2);
        rest = body.get(advance..).unwrap_or(&[]);
    }

    let (audio_format, channels, sr, bps) = fmt.ok_or_else(|| anyhow!("wav missing fmt chunk"))?;
    if audio_format != 1 {
        bail!("wav: only PCM supported (format={audio_format})");
    }
    if bps != 16 {
        bail!("wav: only 16-bit PCM supported, got {bps}-bit");
    }
    let data = data_chunk.ok_or_else(|| anyhow!("wav missing data chunk"))?;
    if data.len() % 2 != 0 {
        bail!("wav data chunk not aligned to 2-byte sample width");
    }

    const SCALE: f32 = 1.0_f32 / 32_768.0;
    let samples = data
        .chunks_exact(2)
        .map(|pair| i16::from_le_bytes([pair[0], pair[1]]) as f32 * SCALE)
        .collect();
    Ok((channels, sr, samples))
}
/// One media attachment to splice into a token stream.
///
/// `count` is the number of placeholder tokens `expand_media_placeholders`
/// emits for the slot (not counting any begin/end bracket tokens).
#[derive(Debug, Clone, Copy)]
pub enum MediaSlot {
/// Image slot; expands to `count` copies of `image_token_id`.
Image { count: usize },
/// Audio slot; expands to `count` copies of `audio_token_id`.
Audio { count: usize },
/// Video slot; expands to `count` copies of `video_token_id`.
Video { count: usize },
}
/// Image marker in the `<|...|>` form.
pub const IMAGE_MARKER_HF: &str = "<|image|>";
/// Audio marker in the `<|...|>` form.
pub const AUDIO_MARKER_HF: &str = "<|audio|>";
/// Video marker in the `<|...|>` form.
pub const VIDEO_MARKER_HF: &str = "<|video|>";
/// Short-form image marker.
pub const IMAGE_MARKER: &str = "<image>";
/// Short-form audio marker.
pub const AUDIO_MARKER: &str = "<audio>";
/// Short-form video marker. This previously duplicated [`VIDEO_MARKER_HF`],
/// which left the short form `<video>` unrecognised, unlike the image and
/// audio pairs.
pub const VIDEO_MARKER: &str = "<video>";
/// Modality associated with a marker string in `next_media_marker`'s table.
///
/// NOTE(review): the kind is stored next to each marker but never read back in
/// this file — marker/slot modality agreement is not checked; confirm whether
/// that is intended.
#[derive(Clone, Copy)]
enum MediaMarkerKind {
Image,
Audio,
Video,
}
/// Finds the earliest media marker in `prompt`.
///
/// Returns the byte offset of the marker together with the marker text, or
/// `None` when no marker occurs. If two markers match at the same offset, the
/// one listed first in the table wins.
fn next_media_marker(prompt: &str) -> Option<(usize, &'static str)> {
    const CANDIDATES: [(&str, MediaMarkerKind); 6] = [
        (IMAGE_MARKER_HF, MediaMarkerKind::Image),
        (IMAGE_MARKER, MediaMarkerKind::Image),
        (AUDIO_MARKER_HF, MediaMarkerKind::Audio),
        (AUDIO_MARKER, MediaMarkerKind::Audio),
        (VIDEO_MARKER_HF, MediaMarkerKind::Video),
        (VIDEO_MARKER, MediaMarkerKind::Video),
    ];
    CANDIDATES
        .iter()
        .filter_map(|&(marker, _)| prompt.find(marker).map(|at| (at, marker)))
        // `min_by_key` keeps the first of several equal minima, which
        // preserves table order on ties.
        .min_by_key(|&(at, _)| at)
}
/// Tokenizes a prompt containing media markers and splices in placeholder
/// tokens for each media slot.
///
/// The prompt is split at every marker found by `next_media_marker`; each text
/// piece (including empty ones) is encoded with `encode_fn`, and the pieces
/// are then interleaved with the expanded `slots` by
/// [`expand_media_placeholders`]. Markers are matched to slots by position.
///
/// # Errors
/// Fails when `encode_fn` fails, when the number of markers differs from
/// `slots.len()`, or when placeholder expansion fails.
pub fn tokenize_with_media<F>(
    prompt: &str,
    slots: &[MediaSlot],
    cfg: &GemmaMultimodalConfig,
    mut encode_fn: F,
) -> Result<Vec<u32>>
where
    F: FnMut(&str) -> Result<Vec<u32>>,
{
    let mut text_chunks: Vec<Vec<u32>> = Vec::with_capacity(slots.len() + 1);
    let mut markers_seen = 0usize;
    let mut rest = prompt;

    // Peel off "text, marker" pairs until no marker remains.
    while let Some((at, marker)) = next_media_marker(rest) {
        text_chunks.push(encode_fn(&rest[..at])?);
        rest = &rest[at + marker.len()..];
        markers_seen += 1;
    }
    // Whatever follows the last marker (possibly empty) is the final chunk.
    text_chunks.push(encode_fn(rest)?);

    if markers_seen != slots.len() {
        bail!(
            "prompt has {markers_seen} media markers but {} slot(s) supplied",
            slots.len(),
        );
    }
    expand_media_placeholders(&text_chunks, slots, cfg)
}
/// Interleaves encoded text chunks with placeholder tokens for each slot.
///
/// The output is `chunk[0], slot[0], chunk[1], slot[1], …, chunk[n]`. A slot
/// expands to an optional begin token, `count` copies of the modality's
/// placeholder token, and an optional end token. Image and video slots are
/// bracketed by `boi_token_id` / `eoi_token_id`; audio slots by
/// `boa_token_id` / `eoa_token_index`. Unset bracket tokens are omitted.
///
/// # Errors
/// Fails when `text_chunks.len() != slots.len() + 1`, or when a slot's
/// placeholder token id is unset in `cfg`.
pub fn expand_media_placeholders(
    text_chunks: &[Vec<u32>],
    slots: &[MediaSlot],
    cfg: &GemmaMultimodalConfig,
) -> Result<Vec<u32>> {
    if text_chunks.len() != slots.len() + 1 {
        bail!(
            "text_chunks ({}) must equal slots ({}) + 1",
            text_chunks.len(),
            slots.len(),
        );
    }

    let text_len: usize = text_chunks.iter().map(|c| c.len()).sum();
    let mut out: Vec<u32> = Vec::with_capacity(text_len + slots.len() * 16);

    for (chunk, slot) in text_chunks.iter().zip(slots) {
        out.extend_from_slice(chunk);

        // Resolve the slot into (repeat count, placeholder, begin, end).
        let (count, placeholder, begin, end) = match *slot {
            MediaSlot::Image { count } => (
                count,
                cfg.image_token_id
                    .ok_or_else(|| anyhow!("image slot requested but image_token_id is unset")),
                cfg.boi_token_id,
                cfg.eoi_token_id,
            ),
            MediaSlot::Audio { count } => (
                count,
                cfg.audio_token_id
                    .ok_or_else(|| anyhow!("audio slot requested but audio_token_id is unset")),
                cfg.boa_token_id,
                cfg.eoa_token_index,
            ),
            MediaSlot::Video { count } => (
                count,
                cfg.video_token_id
                    .ok_or_else(|| anyhow!("video slot requested but video_token_id is unset")),
                cfg.boi_token_id,
                cfg.eoi_token_id,
            ),
        };
        let placeholder = placeholder?;

        // `Option` iterates over zero or one item, so unset brackets vanish.
        out.extend(begin);
        out.extend(std::iter::repeat(placeholder).take(count));
        out.extend(end);
    }

    // The length check above guarantees exactly one trailing chunk.
    if let Some(tail) = text_chunks.last() {
        out.extend_from_slice(tail);
    }
    Ok(out)
}
pub fn fuse_multimodal_embeddings(
text_embeds: &mut [f32],
token_ids: &[u32],
hidden: usize,
cfg: &GemmaMultimodalConfig,
image_embeds: &[f32],
audio_embeds: &[f32],
video_embeds: &[f32],
) -> Result<()> {
if text_embeds.len() != token_ids.len() * hidden {
bail!(
"text_embeds {} != tokens {} * hidden {}",
text_embeds.len(),
token_ids.len(),
hidden,
);
}
let mut img_cursor = 0usize;
let mut aud_cursor = 0usize;
let mut vid_cursor = 0usize;
for (i, &tok) in token_ids.iter().enumerate() {
let dst = &mut text_embeds[i * hidden..(i + 1) * hidden];
if Some(tok) == cfg.image_token_id {
let src = image_embeds
.get(img_cursor * hidden..(img_cursor + 1) * hidden)
.ok_or_else(|| {
anyhow!(
"image_embeds exhausted at token {i}: need {} rows, have {}",
img_cursor + 1,
image_embeds.len() / hidden,
)
})?;
dst.copy_from_slice(src);
img_cursor += 1;
} else if Some(tok) == cfg.video_token_id {
let src = video_embeds
.get(vid_cursor * hidden..(vid_cursor + 1) * hidden)
.ok_or_else(|| {
anyhow!(
"video_embeds exhausted at token {i}: need {} rows, have {}",
vid_cursor + 1,
video_embeds.len() / hidden,
)
})?;
dst.copy_from_slice(src);
vid_cursor += 1;
} else if Some(tok) == cfg.audio_token_id {
let src = audio_embeds
.get(aud_cursor * hidden..(aud_cursor + 1) * hidden)
.ok_or_else(|| {
anyhow!(
"audio_embeds exhausted at token {i}: need {} rows, have {}",
aud_cursor + 1,
audio_embeds.len() / hidden,
)
})?;
dst.copy_from_slice(src);
aud_cursor += 1;
}
}
Ok(())
}
// Unit tests for config parsing, embedding fusion, patch extraction, audio
// framing, WAV decoding, resampling, and media-token expansion.
#[cfg(test)]
mod tests {
use super::*;
// Config fixture in the unified layout: top-level token ids plus nested
// `audio_config` / `vision_config` objects.
const GEMMA_4_12B_FULL_CONFIG: &str = r#"{
"model_type": "gemma4_unified",
"audio_token_id": 258881,
"image_token_id": 258880,
"video_token_id": 258884,
"boi_token_id": 255999,
"eoi_token_id": 258882,
"boa_token_id": 256000,
"eoa_token_index": 258883,
"audio_config": {
"audio_embed_dim": 640,
"audio_samples_per_token": 640,
"hidden_size": 640,
"output_proj_dims": 640,
"rms_norm_eps": 1e-6
},
"vision_config": {
"mm_embed_dim": 3840,
"mm_posemb_size": 1120,
"model_patch_size": 48,
"num_soft_tokens": 280,
"output_proj_dims": 3840,
"patch_size": 16,
"pooling_kernel_size": 3,
"rms_norm_eps": 1e-6
}
}"#;
// The fixture parses into both sub-configs and the three placeholder ids.
#[test]
fn multimodal_config_parses_unified_layout() {
let cfg = GemmaMultimodalConfig::parse_json(GEMMA_4_12B_FULL_CONFIG).unwrap();
let vision = cfg.vision.as_ref().unwrap();
let audio = cfg.audio.as_ref().unwrap();
assert_eq!(vision.patch_size, 16);
assert_eq!(vision.model_patch_size, 48);
assert_eq!(vision.mm_embed_dim, 3840);
assert_eq!(vision.num_soft_tokens, 280);
assert_eq!(vision.output_proj_dims, 3840);
assert_eq!(vision.pooling_kernel_size, 3);
assert_eq!(audio.audio_samples_per_token, 640);
assert_eq!(audio.audio_embed_dim, 640);
assert_eq!(audio.output_proj_dims, 640);
assert_eq!(cfg.image_token_id, Some(258_880));
assert_eq!(cfg.audio_token_id, Some(258_881));
assert_eq!(cfg.video_token_id, Some(258_884));
}
// Only the rows at placeholder token positions are overwritten; the plain
// text rows keep their original values.
#[test]
fn fuse_replaces_only_placeholder_rows() {
let cfg = GemmaMultimodalConfig {
image_token_id: Some(100),
audio_token_id: Some(200),
..Default::default()
};
let hidden = 4;
// Four rows of width 4: text, image placeholder, audio placeholder, text.
let mut text = vec![
1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, ];
let ids = [42, 100, 200, 43];
let img = vec![7.0, 7.0, 7.0, 7.0];
let aud = vec![9.0, 9.0, 9.0, 9.0];
fuse_multimodal_embeddings(&mut text, &ids, hidden, &cfg, &img, &aud, &[]).unwrap();
assert_eq!(&text[0..4], &[1.0, 1.0, 1.0, 1.0]);
assert_eq!(&text[4..8], &[7.0, 7.0, 7.0, 7.0]);
assert_eq!(&text[8..12], &[9.0, 9.0, 9.0, 9.0]);
assert_eq!(&text[12..16], &[2.0, 2.0, 2.0, 2.0]);
}
// Two image placeholders but only one image row: the second lookup fails.
#[test]
fn fuse_errors_when_media_runs_out() {
let cfg = GemmaMultimodalConfig {
image_token_id: Some(100),
..Default::default()
};
let mut text = vec![0.0; 8];
let ids = [100, 100];
let img = vec![1.0; 4]; let err = fuse_multimodal_embeddings(&mut text, &ids, 4, &cfg, &img, &[], &[]).unwrap_err();
assert!(err.to_string().contains("image_embeds exhausted"));
}
// With no placeholder ids configured, no token matches and nothing changes.
#[test]
fn empty_config_is_no_op() {
let cfg = GemmaMultimodalConfig::default();
let mut text = vec![1.0, 2.0, 3.0, 4.0];
let ids = [10, 20];
fuse_multimodal_embeddings(&mut text, &ids, 2, &cfg, &[], &[], &[]).unwrap();
assert_eq!(text, vec![1.0, 2.0, 3.0, 4.0]);
}
// A 4x4 image with 2x2 patches yields 4 patches of 12 floats; the first
// values are the first pixels scaled by 1/255.
#[test]
fn extract_image_patches_shapes_match_expected_grid() {
let rgb: Vec<u8> = (0..(4 * 4 * 3) as u8).collect();
let out = extract_image_patches(&rgb, 4, 4, 2).unwrap();
assert_eq!(out.len(), 4 * 12);
assert!((out[0] - 0.0 / 255.0).abs() < 1e-6);
assert!((out[1] - 1.0 / 255.0).abs() < 1e-6);
assert!((out[2] - 2.0 / 255.0).abs() < 1e-6);
assert!((out[3] - 3.0 / 255.0).abs() < 1e-6);
}
// A 5x5 image with 2x2 patches keeps only the 2x2 grid of full patches.
#[test]
fn extract_image_patches_truncates_partial_pixels() {
let rgb = vec![0u8; 5 * 5 * 3];
let out = extract_image_patches(&rgb, 5, 5, 2).unwrap();
assert_eq!(out.len(), 4 * 12);
}
// A buffer one byte short of width * height * 3 is rejected.
#[test]
fn extract_image_patches_rejects_size_mismatch() {
let rgb = vec![0u8; 4 * 4 * 3 - 1];
assert!(extract_image_patches(&rgb, 4, 4, 2).is_err());
}
// 1500 samples at 640 per frame need 3 frames; the tail is zero padding.
#[test]
fn frame_audio_samples_pads_last_frame() {
let samples = vec![1.0f32; 1500]; let (out, n) = frame_audio_samples(&samples, 640).unwrap();
assert_eq!(n, 3);
assert_eq!(out.len(), 3 * 640);
for &v in &out[1500..] {
assert_eq!(v, 0.0);
}
for &v in &out[..1500] {
assert_eq!(v, 1.0);
}
}
// Empty input still produces a single frame.
#[test]
fn frame_audio_samples_minimum_one_frame() {
let (out, n) = frame_audio_samples(&[], 640).unwrap();
assert_eq!(n, 1);
assert_eq!(out.len(), 640);
}
// Each slot expands to begin token, repeated placeholder, end token, placed
// between the surrounding text chunks.
#[test]
fn expand_media_placeholders_brackets_and_inlines_tokens() {
let cfg = GemmaMultimodalConfig {
image_token_id: Some(900),
boi_token_id: Some(800),
eoi_token_id: Some(801),
audio_token_id: Some(950),
boa_token_id: Some(850),
eoa_token_index: Some(851),
..Default::default()
};
let chunks = vec![vec![1, 2], vec![3], vec![4, 5]];
let slots = vec![MediaSlot::Image { count: 4 }, MediaSlot::Audio { count: 2 }];
let out = expand_media_placeholders(&chunks, &slots, &cfg).unwrap();
assert_eq!(
out,
vec![
1, 2, 800, 900, 900, 900, 900, 801, 3,
850, 950, 950, 851, 4, 5
],
);
}
// One slot requires two text chunks; a single chunk is rejected.
#[test]
fn expand_media_placeholders_rejects_mismatched_chunks() {
let cfg = GemmaMultimodalConfig {
image_token_id: Some(900),
..Default::default()
};
let chunks = vec![vec![1]];
let slots = vec![MediaSlot::Image { count: 4 }];
assert!(expand_media_placeholders(&chunks, &slots, &cfg).is_err());
}
// The standalone graphs expose only the media tensor as an input key.
#[test]
fn standalone_projector_graphs_only_take_media_as_input() {
let v_cfg = GemmaVisionConfig::default();
let g = build_vision_projection_graph(1, 16, &v_cfg).unwrap();
assert_eq!(g.input_keys, vec!["patches".to_string()]);
let a_cfg = GemmaAudioConfig::default();
let g = build_audio_projection_graph(1, 8, &a_cfg, 3840).unwrap();
assert_eq!(g.input_keys, vec!["frames".to_string()]);
}
// Builds a minimal 44-byte-header WAV in memory (PCM, mono, 16 kHz, 16-bit)
// and checks the decoded sample values.
#[test]
fn parse_wav_decodes_minimal_pcm16_mono() {
let samples_i16: [i16; 4] = [0, 16_384, -16_384, 32_767];
let mut bytes = Vec::new();
bytes.extend_from_slice(b"RIFF");
let total_size = 4 + (8 + 16) + (8 + samples_i16.len() * 2); bytes.extend_from_slice(&(total_size as u32).to_le_bytes());
bytes.extend_from_slice(b"WAVE");
bytes.extend_from_slice(b"fmt ");
bytes.extend_from_slice(&16u32.to_le_bytes());
// fmt payload in order: format tag, channels, sample rate, byte rate,
// block align, bits per sample; the "data" tag follows.
bytes.extend_from_slice(&1u16.to_le_bytes()); bytes.extend_from_slice(&1u16.to_le_bytes()); bytes.extend_from_slice(&16_000u32.to_le_bytes()); bytes.extend_from_slice(&32_000u32.to_le_bytes()); bytes.extend_from_slice(&2u16.to_le_bytes()); bytes.extend_from_slice(&16u16.to_le_bytes()); bytes.extend_from_slice(b"data");
bytes.extend_from_slice(&((samples_i16.len() * 2) as u32).to_le_bytes());
for s in samples_i16 {
bytes.extend_from_slice(&s.to_le_bytes());
}
let pcm = parse_wav_16khz_mono(&bytes).unwrap();
assert_eq!(pcm.len(), 4);
assert!((pcm[0] - 0.0).abs() < 1e-4);
assert!((pcm[1] - 0.5).abs() < 1e-3);
assert!((pcm[2] - (-0.5)).abs() < 1e-3);
assert!((pcm[3] - 1.0).abs() < 1e-3);
}
// Resampling a constant signal 48 kHz -> 16 kHz keeps the value and yields
// roughly a third as many samples.
#[test]
fn resample_linear_preserves_constants() {
let src = vec![0.7f32; 1000];
let out = resample_linear(&src, 48_000, 16_000);
assert!((out.len() as i32 - 333).abs() <= 1);
for &v in &out {
assert!((v - 0.7).abs() < 1e-5);
}
}
// Uses a byte-per-token stub encoder to check that text is split at the
// markers and the expanded slots land between the pieces.
#[test]
fn tokenize_with_media_splits_and_expands() {
let cfg = GemmaMultimodalConfig {
image_token_id: Some(900),
boi_token_id: Some(800),
eoi_token_id: Some(801),
audio_token_id: Some(950),
boa_token_id: Some(850),
eoa_token_index: Some(851),
..Default::default()
};
let encode = |s: &str| -> Result<Vec<u32>> { Ok(s.bytes().map(|b| b as u32).collect()) };
let prompt = "hi <image> see <audio> bye";
let slots = vec![MediaSlot::Image { count: 2 }, MediaSlot::Audio { count: 1 }];
let out = tokenize_with_media(prompt, &slots, &cfg, encode).unwrap();
let mut expected: Vec<u32> = b"hi ".iter().map(|b| *b as u32).collect();
expected.extend([800, 900, 900, 801]);
expected.extend(b" see ".iter().map(|b| *b as u32));
expected.extend([850, 950, 851]);
expected.extend(b" bye".iter().map(|b| *b as u32));
assert_eq!(out, expected);
}
// Two markers in the prompt but only one slot supplied is an error.
#[test]
fn tokenize_with_media_rejects_slot_marker_mismatch() {
let cfg = GemmaMultimodalConfig {
image_token_id: Some(900),
..Default::default()
};
let encode = |_: &str| -> Result<Vec<u32>> { Ok(vec![]) };
let err = tokenize_with_media(
"a <image> b <image> c",
&[MediaSlot::Image { count: 1 }],
&cfg,
encode,
)
.unwrap_err();
assert!(err.to_string().contains("media markers"));
}
}