use anyhow::{Result, bail};
/// Per-patch `(x, y)` grid coordinates; padding slots carry `(-1, -1)`.
type PatchGrid = Vec<(i32, i32)>;
/// Flattened patch data paired with the grid coordinate of each patch.
type PatchGridResult = (Vec<f32>, PatchGrid);
/// Listed `max_soft_tokens` budgets.
/// NOTE(review): not referenced in this chunk — presumably used by callers to validate the
/// requested budget; confirm.
pub const SUPPORTED_MAX_SOFT_TOKENS: [usize; 5] = [70, 140, 280, 560, 1120];
/// Hard cap on audio input length in samples; longer inputs are truncated to this.
pub const MAX_AUDIO_SAMPLES: usize = 480_000;
/// Audio frame counts are rounded up to a multiple of this value.
pub const AUDIO_FRAME_PAD_MULTIPLE: usize = 128;
/// Returns how many soft tokens a `height × width` image yields after it is resized by
/// [`aspect_ratio_preserving_size`] and its patch grid is pooled.
///
/// The resizer gets a budget of `max_soft_tokens · pooling_kernel_size²` patches. The
/// resulting patch count (`(h / patch_size) · (w / patch_size)`) is divided by
/// `pooling_kernel_size²` to give the number of soft tokens.
///
/// # Errors
/// Propagates any error from [`aspect_ratio_preserving_size`].
pub fn compute_num_soft_tokens_from_size(
    height: usize,
    width: usize,
    patch_size: usize,
    pooling_kernel_size: usize,
    max_soft_tokens: usize,
) -> Result<usize> {
    let kernel_area = pooling_kernel_size * pooling_kernel_size;
    let patch_budget = max_soft_tokens * kernel_area;
    let (target_h, target_w) = aspect_ratio_preserving_size(
        height,
        width,
        patch_size,
        patch_budget,
        pooling_kernel_size,
    )?;
    let teacher_patches = (target_h / patch_size) * (target_w / patch_size);
    Ok(teacher_patches / kernel_area)
}
/// Keeps only the `hidden`-wide rows of `projected` whose matching position is valid.
///
/// `projected` is read as consecutive rows of `hidden` floats; any trailing partial row is
/// ignored. Row `i` is kept when `positions[i]` exists and both of its coordinates are
/// non-negative. Rows without a matching entry in `positions` are dropped. The kept rows
/// are concatenated in their original order. A `hidden` of zero yields an empty result.
pub fn strip_valid_vision_rows(
    projected: &[f32],
    positions: &[(i32, i32)],
    hidden: usize,
) -> Vec<f32> {
    if hidden == 0 {
        return Vec::new();
    }
    projected
        .chunks_exact(hidden)
        .zip(positions)
        .filter(|(_, pos)| pos.0 >= 0 && pos.1 >= 0)
        .flat_map(|(row, _)| row.iter().copied())
        .collect()
}
/// Number of audio tokens produced for a clip of `num_samples` samples.
///
/// The clip is first capped at [`MAX_AUDIO_SAMPLES`]. Each token covers
/// `samples_per_token` samples, and a trailing partial token counts as a whole one. The
/// result is at least 1 and never more than `max_tokens`, so it is 0 when `max_tokens`
/// is 0.
///
/// # Panics
/// Panics if `samples_per_token` is zero.
pub fn unified_audio_token_count(
    num_samples: usize,
    samples_per_token: usize,
    max_tokens: usize,
) -> usize {
    let usable = std::cmp::min(num_samples, MAX_AUDIO_SAMPLES);
    let tokens = std::cmp::max(usable.div_ceil(samples_per_token), 1);
    std::cmp::min(tokens, max_tokens)
}
/// Prepares raw audio samples for the audio encoder.
///
/// - The clip is capped at [`MAX_AUDIO_SAMPLES`] samples and split into frames of
///   `samples_per_token` samples.
/// - The frame count is at least 1 and at most `max_tokens`, using the same formula as
///   [`unified_audio_token_count`]. Audio beyond those frames is dropped.
/// - The result is zero-padded up to a multiple of [`AUDIO_FRAME_PAD_MULTIPLE`] frames.
///
/// The returned length is therefore always `padded_frames · samples_per_token`.
///
/// # Panics
/// Panics if `samples_per_token` is zero.
pub fn prepare_unified_audio_samples(
    samples: &[f32],
    samples_per_token: usize,
    max_tokens: usize,
) -> Vec<f32> {
    let capped = &samples[..samples.len().min(MAX_AUDIO_SAMPLES)];
    let num_frames = capped
        .len()
        .div_ceil(samples_per_token)
        .max(1)
        .min(max_tokens);
    // Keep only the samples that belong to the counted frames. When `max_tokens` cuts the
    // clip short, the pad region must be zeros, not leftover audio the token count ignores.
    let kept = capped.len().min(num_frames * samples_per_token);
    let padded_frames = num_frames.div_ceil(AUDIO_FRAME_PAD_MULTIPLE) * AUDIO_FRAME_PAD_MULTIPLE;
    let padded_len = padded_frames * samples_per_token;
    let mut out = Vec::with_capacity(padded_len);
    out.extend_from_slice(&capped[..kept]);
    out.resize(padded_len, 0.0);
    out
}
/// One image preprocessed into model patches.
///
/// As built by `load_unified_image`, the patch list is padded out to `max_soft_tokens`
/// slots.
#[derive(Debug, Clone)]
pub struct UnifiedImageBatch {
    /// Flattened patch data, one `model_dim`-float row per slot; padding slots are zero.
    pub patches: Vec<f32>,
    /// Per-slot `(x, y)` grid position; padding slots hold `(-1, -1)`.
    pub positions: Vec<(i32, i32)>,
    /// Number of leading slots that hold real (non-padding) patches.
    pub num_valid: usize,
}
/// Picks a resize target `(height, width)` that roughly preserves the aspect ratio of a
/// `height × width` image while fitting a budget of `max_patches` patches of
/// `patch_size × patch_size` pixels.
///
/// Both sides are rounded down to multiples of `pooling_kernel_size · patch_size`, so the
/// patch grid can be pooled evenly.
///
/// If one side would round to zero (a very thin image), that side is set to one multiple.
/// The other side becomes `floor(long / short)` multiples, capped at
/// `(max_patches / pooling_kernel_size²) · pooling_kernel_size · patch_size`. Both returned
/// sides are at least one multiple.
///
/// # Errors
/// Fails in either of these cases:
/// - `height` or `width` is zero.
/// - Both sides round to zero, meaning the image is too small for the patch/pooling
///   sizes. This also covers a zero `patch_size`, `pooling_kernel_size` or `max_patches`.
pub fn aspect_ratio_preserving_size(
    height: usize,
    width: usize,
    patch_size: usize,
    max_patches: usize,
    pooling_kernel_size: usize,
) -> Result<(usize, usize)> {
    // A zero side makes the scale factor below infinite/NaN. The saturating float→usize
    // cast then yields usize::MAX, and the multiply overflows. Reject it up front.
    if height == 0 || width == 0 {
        bail!("cannot resize an empty image ({height}×{width})");
    }
    let total_px = height * width;
    let target_px = max_patches * patch_size * patch_size;
    // Uniform scale that maps the image area onto the patch-budget area.
    let factor = (target_px as f64 / total_px as f64).sqrt();
    let ideal_height = factor * height as f64;
    let ideal_width = factor * width as f64;
    let side_mult = pooling_kernel_size * patch_size;
    let mut target_height = (ideal_height / side_mult as f64).floor() as usize * side_mult;
    let mut target_width = (ideal_width / side_mult as f64).floor() as usize * side_mult;
    if target_height == 0 && target_width == 0 {
        bail!(
            "resize target is 0×0; image too small for patch_size={patch_size} \
             pooling_kernel_size={pooling_kernel_size}"
        );
    }
    let max_side_length = (max_patches / (pooling_kernel_size * pooling_kernel_size)) * side_mult;
    // Very wide image: height collapsed to zero, so give it one unit and size the width
    // from the aspect ratio instead, capped by the budget.
    if target_height == 0 {
        target_height = side_mult;
        target_width =
            ((width as f64 / height as f64).floor() as usize * side_mult).min(max_side_length);
    }
    // Very tall image: the mirror case of the branch above.
    if target_width == 0 {
        target_width = side_mult;
        target_height =
            ((height as f64 / width as f64).floor() as usize * side_mult).min(max_side_length);
    }
    Ok((target_height.max(side_mult), target_width.max(side_mult)))
}
/// Cuts an interleaved 8-bit RGB image into non-overlapping `patch_size × patch_size`
/// "teacher" patches, scaling each channel value to `[0, 1]` by dividing by 255.
///
/// `rgb` is row-major with 3 bytes per pixel. Patches are emitted row-major over the patch
/// grid. Each patch is `patch_size · patch_size · 3` floats: its pixel rows in order, with
/// channels interleaved. Pixels in a trailing partial patch row or column are dropped.
///
/// Returns the flattened patches and each patch's `(column, row)` grid position.
///
/// # Errors
/// Fails if `rgb.len() != width · height · 3` or if `patch_size` is zero.
pub fn teacher_patches_from_rgb(
    rgb: &[u8],
    width: usize,
    height: usize,
    patch_size: usize,
) -> Result<PatchGridResult> {
    if rgb.len() != width * height * 3 {
        bail!("rgb len {} != {width}×{height}×3", rgb.len());
    }
    // Report a zero patch size as an error instead of panicking in the divisions below.
    if patch_size == 0 {
        bail!("patch_size must be non-zero");
    }
    let patch_cols = width / patch_size;
    let patch_rows = height / patch_size;
    let num = patch_rows * patch_cols;
    // One pixel row of a patch is contiguous both in the source image and in the patch.
    let row_len = patch_size * 3;
    let per = patch_size * row_len;
    let inv = 1.0 / 255.0;
    let mut patches = vec![0f32; num * per];
    let mut positions = Vec::with_capacity(num);
    for (idx, dst_patch) in patches.chunks_exact_mut(per).enumerate() {
        let (pr, pc) = (idx / patch_cols, idx % patch_cols);
        positions.push((pc as i32, pr as i32));
        for (py, dst_row) in dst_patch.chunks_exact_mut(row_len).enumerate() {
            let src = ((pr * patch_size + py) * width + pc * patch_size) * 3;
            for (d, &s) in dst_row.iter_mut().zip(&rgb[src..src + row_len]) {
                *d = s as f32 * inv;
            }
        }
    }
    Ok((patches, positions))
}
/// Regroups teacher patches into `num_model_patches` model patches. Each model patch is a
/// `k × k` block of teacher patches stitched into one larger square patch.
///
/// Inputs:
/// - `patches` holds `L = patches.len() / teacher_patch_dim` teacher patches back to back.
///   Each is `patch_size × patch_size × 3` floats (pixel rows in order, channels
///   interleaved), with `teacher_patch_dim = 3 · patch_size²`.
/// - `positions[i]` is the `(x, y)` grid coordinate of teacher patch `i`.
/// - `k` is derived from `L = num_model_patches · k²`.
///
/// Teacher patches are first ordered so that each run of `k²` consecutive entries forms one
/// `k × k` block. Blocks are row-major over the grid, and teachers are row-major inside a
/// block. Each block is then written out as a `(k·patch_size) × (k·patch_size) × 3` patch.
///
/// Returns the merged patches and, per model patch, `(min x / k, min y / k)` over its
/// teachers with non-negative `x`, or `(-1, -1)` if there are none.
///
/// # Errors
/// Fails in any of these cases:
/// - `teacher_patch_dim` or `num_model_patches` is zero.
/// - `L` is not `num_model_patches · k²` for some integer `k`.
/// - `positions` has fewer than `L` entries.
/// - `teacher_patch_dim` is not `3 · patch_size²`.
pub fn patches_merge(
    patches: &[f32],
    positions: &[(i32, i32)],
    num_model_patches: usize,
    teacher_patch_dim: usize,
) -> Result<PatchGridResult> {
    if teacher_patch_dim == 0 || num_model_patches == 0 {
        bail!(
            "teacher_patch_dim ({teacher_patch_dim}) and num_model_patches \
             ({num_model_patches}) must both be non-zero"
        );
    }
    let l = patches.len() / teacher_patch_dim;
    if l % num_model_patches != 0 {
        bail!("cannot merge {l} teacher patches into {num_model_patches} model patches");
    }
    // Integer square root: exact, with no float rounding.
    let k = (l / num_model_patches).isqrt();
    if k * k * num_model_patches != l {
        bail!("patch count {l} is not num_model×k²");
    }
    if positions.len() < l {
        bail!("got {} positions for {l} teacher patches", positions.len());
    }
    let patch_size = (teacher_patch_dim / 3).isqrt();
    if patch_size * patch_size * 3 != teacher_patch_dim {
        bail!("teacher_patch_dim {teacher_patch_dim} is not 3×patch_size²");
    }
    // Floats in one pixel row of a teacher patch / of a merged model patch.
    let row_len = patch_size * 3;
    let merged_row_len = k * row_len;
    let model_dim = k * patch_size * merged_row_len;

    // Sort key: block index (row-major over k×k blocks) then offset inside the block.
    // NOTE(review): assumes non-negative positions and a grid width divisible by k;
    // negative coordinates wrap through `as usize`.
    let max_x = positions.iter().map(|(x, _)| *x).max().unwrap_or(0).max(0) as usize + 1;
    let mut order: Vec<usize> = (0..l).collect();
    order.sort_by_key(|&i| {
        let (x, y) = positions[i];
        let kx = (x as usize) / k;
        let ky = (y as usize) / k;
        let num_from_tl = k * k * kx + k * max_x * ky;
        let px = (x as usize) % k;
        let py = (y as usize) % k;
        num_from_tl + px + py * k
    });

    let mut merged = vec![0f32; num_model_patches * model_dim];
    let mut merged_pos = vec![(-1, -1); num_model_patches];
    for mp in 0..num_model_patches {
        let kernel = &order[mp * k * k..(mp + 1) * k * k];
        let dst_base = mp * model_dim;
        let mut block_xy: Option<(i32, i32)> = None;
        for (slot, &src_i) in kernel.iter().enumerate() {
            // Teacher `slot` of the block sits at sub-block (kx, ky) of the model patch.
            let (ky, kx) = (slot / k, slot % k);
            let (x, y) = positions[src_i];
            if x >= 0 {
                let (bx, by) = (x / k as i32, y / k as i32);
                block_xy = Some(match block_xy {
                    Some((mx, my)) => (mx.min(bx), my.min(by)),
                    None => (bx, by),
                });
            }
            // Copy straight from the input, one contiguous pixel row at a time.
            let src = &patches[src_i * teacher_patch_dim..(src_i + 1) * teacher_patch_dim];
            for (py, src_row) in src.chunks_exact(row_len).enumerate() {
                let dst = dst_base + (ky * patch_size + py) * merged_row_len + kx * row_len;
                merged[dst..dst + row_len].copy_from_slice(src_row);
            }
        }
        if let Some(xy) = block_xy {
            merged_pos[mp] = xy;
        }
    }
    Ok((merged, merged_pos))
}
/// Pads model patches and their positions out to exactly `max_slots` slots.
///
/// Real patches keep their order at the front. Padding slots are zero-filled and get the
/// position `(-1, -1)`. The input vectors are grown in place rather than copied.
///
/// # Panics
/// Panics in any of these cases:
/// - `model_dim` is zero.
/// - `patches.len()` is not a multiple of `model_dim`.
/// - `positions.len()` differs from the number of patches.
/// - There are more than `max_slots` patches.
pub fn pad_patches_to_max(
    mut patches: Vec<f32>,
    mut positions: Vec<(i32, i32)>,
    model_dim: usize,
    max_slots: usize,
) -> (Vec<f32>, Vec<(i32, i32)>) {
    assert!(model_dim > 0, "pad_patches_to_max: model_dim must be non-zero");
    assert_eq!(
        patches.len() % model_dim,
        0,
        "pad_patches_to_max: patch data is not a whole number of {model_dim}-float patches"
    );
    let n = patches.len() / model_dim;
    assert_eq!(
        positions.len(),
        n,
        "pad_patches_to_max: position count does not match patch count"
    );
    assert!(
        n <= max_slots,
        "pad_patches_to_max: {n} patches do not fit in {max_slots} slots"
    );
    patches.resize(max_slots * model_dim, 0.0);
    positions.resize(max_slots, (-1, -1));
    (patches, positions)
}
pub fn load_unified_image(
path: impl AsRef<std::path::Path>,
patch_size: usize,
pooling_kernel_size: usize,
max_soft_tokens: usize,
) -> Result<UnifiedImageBatch> {
let img = image::open(path.as_ref())
.map_err(|e| anyhow::anyhow!("decode {:?}: {e}", path.as_ref()))?;
let rgb = img.to_rgb8();
let (w, h) = (rgb.width() as usize, rgb.height() as usize);
let max_patches = max_soft_tokens * pooling_kernel_size * pooling_kernel_size;
let (th, tw) =
aspect_ratio_preserving_size(h, w, patch_size, max_patches, pooling_kernel_size)?;
let resized = if (tw, th) != (w, h) {
image::DynamicImage::ImageRgb8(rgb)
.resize_exact(tw as u32, th as u32, image::imageops::FilterType::Triangle)
.to_rgb8()
} else {
rgb
};
let (teacher, tpos) = teacher_patches_from_rgb(
resized.as_raw(),
resized.width() as usize,
resized.height() as usize,
patch_size,
)?;
let teacher_dim = patch_size * patch_size * 3;
let num_model = teacher.len() / teacher_dim / (pooling_kernel_size * pooling_kernel_size);
let (merged, mpos) = patches_merge(&teacher, &tpos, num_model, teacher_dim)?;
let model_dim = (patch_size * pooling_kernel_size).pow(2) * 3;
let num_valid = num_model;
let (patches, positions) = pad_patches_to_max(merged, mpos, model_dim, max_soft_tokens);
Ok(UnifiedImageBatch {
patches,
positions,
num_valid,
})
}
/// Builds a per-position additive bias by summing an x-embedding row and a y-embedding row.
///
/// `pos_embedding` is read as rows of `dim` floats. Coordinate `c` uses row `2c` for the x
/// axis and row `2c + 1` for the y axis, so the two tables are interleaved. The output has
/// one `dim`-float row per entry of `positions`. Entries with a negative coordinate, or a
/// coordinate `>= posemb_size`, get an all-zero row.
///
/// # Panics
/// Panics if `pos_embedding` is too short for an in-range coordinate.
pub fn factorized_pos_bias(
    pos_embedding: &[f32],
    posemb_size: usize,
    dim: usize,
    positions: &[(i32, i32)],
) -> Vec<f32> {
    let mut bias = vec![0f32; positions.len() * dim];
    if dim == 0 {
        return bias;
    }
    // Coordinate -> table index, or None when negative or out of range.
    let table_index = |coord: i32| usize::try_from(coord).ok().filter(|&c| c < posemb_size);
    for (row, &(px, py)) in bias.chunks_exact_mut(dim).zip(positions) {
        let (Some(xi), Some(yi)) = (table_index(px), table_index(py)) else {
            continue;
        };
        let x_row = &pos_embedding[2 * xi * dim..][..dim];
        let y_row = &pos_embedding[(2 * yi + 1) * dim..][..dim];
        for ((b, &xv), &yv) in row.iter_mut().zip(x_row).zip(y_row) {
            *b = xv + yv;
        }
    }
    bias
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `cols × rows` teacher grid with row-major `(x, y)` positions. Teacher patch
    /// `i` has `i + 1` as its first float and zeros elsewhere.
    fn numbered_grid(cols: usize, rows: usize, td: usize) -> (Vec<f32>, Vec<(i32, i32)>) {
        let mut patches = vec![0f32; cols * rows * td];
        let mut pos = Vec::with_capacity(cols * rows);
        for r in 0..rows {
            for c in 0..cols {
                let i = r * cols + c;
                pos.push((c as i32, r as i32));
                patches[i * td] = (i + 1) as f32;
            }
        }
        (patches, pos)
    }

    #[test]
    fn patches_merge_square_grid() {
        let k = 3;
        let ps = 16;
        let td = ps * ps * 3;
        let (cols, rows) = (6, 6);
        let (patches, pos) = numbered_grid(cols, rows, td);
        let num_model = cols * rows / (k * k);
        let (merged, mpos) = patches_merge(&patches, &pos, num_model, td).unwrap();
        let model_dim = (k * ps).pow(2) * 3;
        assert_eq!(merged.len(), num_model * model_dim);
        // Model patches are k×k blocks in row-major block order.
        assert_eq!(mpos, vec![(0, 0), (1, 0), (0, 1), (1, 1)]);
        // The top-left pixel of each model patch comes from its block's top-left teacher.
        assert_eq!(merged[0], 1.0); // teacher 0 at (0, 0)
        assert_eq!(merged[model_dim], 4.0); // teacher 3 at (3, 0)
        assert_eq!(merged[2 * model_dim], 19.0); // teacher 18 at (0, 3)
        assert_eq!(merged[3 * model_dim], 22.0); // teacher 21 at (3, 3)
        // Inside a block, teacher (kx, ky) starts at pixel (kx·ps, ky·ps).
        let row_floats = k * ps * 3;
        assert_eq!(merged[ps * 3], 2.0); // teacher 1 -> (kx=1, ky=0)
        assert_eq!(merged[ps * row_floats], 7.0); // teacher 6 -> (kx=0, ky=1)
        // Merging only rearranges data.
        let total: f32 = merged.iter().sum();
        assert_eq!(total, (1..=cols * rows).sum::<usize>() as f32);
    }

    #[test]
    fn aspect_ratio_square_and_thin_images() {
        assert_eq!(aspect_ratio_preserving_size(1000, 1000, 16, 280 * 9, 3).unwrap(), (768, 768));
        assert_eq!(compute_num_soft_tokens_from_size(1000, 1000, 16, 3, 280).unwrap(), 256);
        // Height collapses to zero: one unit tall, width capped by the patch budget.
        assert_eq!(aspect_ratio_preserving_size(10, 10_000, 16, 280 * 9, 3).unwrap(), (48, 13_440));
    }

    #[test]
    fn audio_token_count_and_padding() {
        assert_eq!(unified_audio_token_count(0, 640, 750), 1);
        assert_eq!(unified_audio_token_count(1000, 640, 750), 2);
        assert_eq!(unified_audio_token_count(10_000_000, 640, 10_000), 750);
        assert_eq!(unified_audio_token_count(10_000, 640, 3), 3);

        let prepared = prepare_unified_audio_samples(&[0.5; 1000], 640, 750);
        assert_eq!(prepared.len(), AUDIO_FRAME_PAD_MULTIPLE * 640);
        assert!(prepared[..1000].iter().all(|&s| s == 0.5));
        assert!(prepared[1000..].iter().all(|&s| s == 0.0));
    }

    #[test]
    fn pad_patches_appends_padding_slots() {
        let (p, q) = pad_patches_to_max(vec![1.0, 2.0], vec![(0, 0)], 2, 2);
        assert_eq!(p, vec![1.0f32, 2.0, 0.0, 0.0]);
        assert_eq!(q, vec![(0, 0), (-1, -1)]);
    }
}