1use anyhow::{Result, bail};
23
/// `(x, y)` grid coordinate for every patch slot; `(-1, -1)` marks a padding slot.
type PatchGrid = Vec<(i32, i32)>;
/// Flattened patch values paired with the grid coordinate of each patch.
type PatchGridResult = (Vec<f32>, PatchGrid);

/// Soft-token budgets supported for an image.
/// NOTE(review): not referenced in this section; presumably callers validate
/// `max_soft_tokens` against this list — confirm.
pub const SUPPORTED_MAX_SOFT_TOKENS: [usize; 5] = [70, 140, 280, 560, 1120];

/// Hard cap on the number of audio samples consumed; anything longer is truncated.
pub const MAX_AUDIO_SAMPLES: usize = 480_000;

/// Audio frame counts are rounded up to a multiple of this value before encoding.
pub const AUDIO_FRAME_PAD_MULTIPLE: usize = 128;
35
36pub fn compute_num_soft_tokens_from_size(
38 height: usize,
39 width: usize,
40 patch_size: usize,
41 pooling_kernel_size: usize,
42 max_soft_tokens: usize,
43) -> Result<usize> {
44 let max_patches = max_soft_tokens * pooling_kernel_size * pooling_kernel_size;
45 let (th, tw) =
46 aspect_ratio_preserving_size(height, width, patch_size, max_patches, pooling_kernel_size)?;
47 let teacher = (th / patch_size) * (tw / patch_size);
48 Ok(teacher / (pooling_kernel_size * pooling_kernel_size))
49}
50
/// Keeps only the `hidden`-wide rows of `projected` whose slot has a valid
/// (non-negative) position, concatenating them in slot order.
///
/// Slots without an entry in `positions` are treated as padding and dropped,
/// as is any trailing partial row of `projected`. With `hidden == 0` there is
/// nothing to copy, so the result is empty.
pub fn strip_valid_vision_rows(
    projected: &[f32],
    positions: &[(i32, i32)],
    hidden: usize,
) -> Vec<f32> {
    if hidden == 0 {
        return Vec::new();
    }
    projected
        .chunks_exact(hidden)
        .enumerate()
        .filter(|&(slot, _)| {
            matches!(positions.get(slot), Some(&(x, y)) if x >= 0 && y >= 0)
        })
        .flat_map(|(_, row)| row.iter().copied())
        .collect()
}
67
68pub fn unified_audio_token_count(
70 num_samples: usize,
71 samples_per_token: usize,
72 max_tokens: usize,
73) -> usize {
74 let capped = num_samples.min(MAX_AUDIO_SAMPLES);
75 capped.div_ceil(samples_per_token).max(1).min(max_tokens)
76}
77
78pub fn prepare_unified_audio_samples(
79 samples: &[f32],
80 samples_per_token: usize,
81 max_tokens: usize,
82) -> Vec<f32> {
83 let capped_len = samples.len().min(MAX_AUDIO_SAMPLES);
84 let mut truncated = samples[..capped_len].to_vec();
85 let mut num_frames = truncated.len().div_ceil(samples_per_token).max(1);
86 num_frames = num_frames.min(max_tokens);
87 let padded_frames = num_frames.div_ceil(AUDIO_FRAME_PAD_MULTIPLE) * AUDIO_FRAME_PAD_MULTIPLE;
88 truncated.resize(padded_frames * samples_per_token, 0.0);
89 truncated
90}
91
/// One image prepared for the unified vision encoder: pooled model patches
/// padded out to a fixed number of slots (as built by `load_unified_image`).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct UnifiedImageBatch {
    /// Flattened patch values, one row of `(patch_size·pooling_kernel_size)²·3`
    /// floats per slot; padding slots are zero-filled.
    pub patches: Vec<f32>,
    /// Pooled-grid `(x, y)` coordinate per slot; `(-1, -1)` marks padding.
    pub positions: Vec<(i32, i32)>,
    /// Number of leading slots that hold real (non-padding) patches.
    pub num_valid: usize,
}
101
102pub fn aspect_ratio_preserving_size(
105 height: usize,
106 width: usize,
107 patch_size: usize,
108 max_patches: usize,
109 pooling_kernel_size: usize,
110) -> Result<(usize, usize)> {
111 let total_px = height * width;
112 let target_px = max_patches * patch_size * patch_size;
113 let factor = (target_px as f64 / total_px as f64).sqrt();
114 let ideal_height = factor * height as f64;
115 let ideal_width = factor * width as f64;
116 let side_mult = pooling_kernel_size * patch_size;
117
118 let mut target_height = (ideal_height / side_mult as f64).floor() as usize * side_mult;
119 let mut target_width = (ideal_width / side_mult as f64).floor() as usize * side_mult;
120
121 if target_height == 0 && target_width == 0 {
122 bail!(
123 "resize target is 0×0; image too small for patch_size={patch_size} \
124 pooling_kernel_size={pooling_kernel_size}"
125 );
126 }
127
128 let max_side_length = (max_patches / (pooling_kernel_size * pooling_kernel_size)) * side_mult;
129 if target_height == 0 {
130 target_height = side_mult;
131 target_width =
132 ((width as f64 / height as f64).floor() as usize * side_mult).min(max_side_length);
133 }
134 if target_width == 0 {
135 target_width = side_mult;
136 target_height =
137 ((height as f64 / width as f64).floor() as usize * side_mult).min(max_side_length);
138 }
139 Ok((target_height.max(side_mult), target_width.max(side_mult)))
140}
141
142pub fn teacher_patches_from_rgb(
144 rgb: &[u8],
145 width: usize,
146 height: usize,
147 patch_size: usize,
148) -> Result<PatchGridResult> {
149 if rgb.len() != width * height * 3 {
150 bail!("rgb len {} != {width}×{height}×3", rgb.len());
151 }
152 let patch_cols = width / patch_size;
153 let patch_rows = height / patch_size;
154 let num = patch_rows * patch_cols;
155 let per = patch_size * patch_size * 3;
156 let inv = 1.0 / 255.0;
157 let mut patches = vec![0f32; num * per];
158 let mut positions = Vec::with_capacity(num);
159 for pr in 0..patch_rows {
160 for pc in 0..patch_cols {
161 let idx = pr * patch_cols + pc;
162 positions.push((pc as i32, pr as i32));
163 let dst_base = idx * per;
164 for py in 0..patch_size {
165 for px in 0..patch_size {
166 let src = ((pr * patch_size + py) * width + (pc * patch_size + px)) * 3;
167 let dst = dst_base + (py * patch_size + px) * 3;
168 patches[dst] = rgb[src] as f32 * inv;
169 patches[dst + 1] = rgb[src + 1] as f32 * inv;
170 patches[dst + 2] = rgb[src + 2] as f32 * inv;
171 }
172 }
173 }
174 }
175 Ok((patches, positions))
176}
177
178pub fn patches_merge(
181 patches: &[f32],
182 positions: &[(i32, i32)],
183 num_model_patches: usize,
184 teacher_patch_dim: usize,
185) -> Result<PatchGridResult> {
186 let l = patches.len() / teacher_patch_dim;
187 if l != num_model_patches {
188 let k2 = l / num_model_patches;
189 if k2 * num_model_patches != l {
190 bail!("cannot merge {l} teacher patches into {num_model_patches} model patches");
191 }
192 }
193 let k = ((l / num_model_patches) as f64).sqrt() as usize;
194 if k * k * num_model_patches != l {
195 bail!("patch count {l} is not num_model×k²");
196 }
197 let patch_size = (teacher_patch_dim / 3).isqrt();
198 let model_dim = (k * patch_size) * (k * patch_size) * 3;
199
200 let max_x = positions.iter().map(|(x, _)| *x).max().unwrap_or(0).max(0) as usize + 1;
202 let mut order: Vec<usize> = (0..l).collect();
203 order.sort_by_key(|&i| {
204 let (x, y) = positions[i];
205 let kx = (x as usize) / k;
206 let ky = (y as usize) / k;
207 let num_from_tl = k * k * kx + k * max_x * ky;
208 let px = (x as usize) % k;
209 let py = (y as usize) % k;
210 num_from_tl + px + py * k
211 });
212
213 let mut kernel_ordered: Vec<f32> = vec![0.0; l * teacher_patch_dim];
214 let mut kernel_pos: Vec<(i32, i32)> = vec![(0, 0); l];
215 for (out_i, &src_i) in order.iter().enumerate() {
216 kernel_ordered[out_i * teacher_patch_dim..(out_i + 1) * teacher_patch_dim]
217 .copy_from_slice(&patches[src_i * teacher_patch_dim..(src_i + 1) * teacher_patch_dim]);
218 kernel_pos[out_i] = positions[src_i];
219 }
220
221 let mut merged = vec![0f32; num_model_patches * model_dim];
222 let mut merged_pos = vec![(-1, -1); num_model_patches];
223
224 for mp in 0..num_model_patches {
225 let base = mp * k * k;
226 let mut min_x = i32::MAX;
227 let mut min_y = i32::MAX;
228 let mut out_off = 0usize;
229 for ky in 0..k {
230 for kx in 0..k {
231 let ti = base + ky * k + kx;
232 let (x, y) = kernel_pos[ti];
233 if x >= 0 {
234 min_x = min_x.min(x / k as i32);
235 min_y = min_y.min(y / k as i32);
236 }
237 for py in 0..patch_size {
239 for px in 0..patch_size {
240 for c in 0..3 {
241 let src = ti * teacher_patch_dim + (py * patch_size + px) * 3 + c;
242 let dst = mp * model_dim
243 + ((ky * patch_size + py) * (k * patch_size)
244 + (kx * patch_size + px))
245 * 3
246 + c;
247 merged[dst] = kernel_ordered[src];
248 }
249 }
250 }
251 out_off += 1;
252 }
253 }
254 let _ = out_off;
255 if min_x != i32::MAX {
256 merged_pos[mp] = (min_x, min_y);
257 }
258 }
259 Ok((merged, merged_pos))
260}
261
/// Extends `patches` and `positions` to exactly `max_slots` slots. The new slots
/// get zero-valued patches of `model_dim` floats and `(-1, -1)` positions.
///
/// The input vectors are grown in place, so their allocations are reused.
///
/// # Panics
/// - `model_dim == 0`.
/// - `patches.len()` is not a multiple of `model_dim`.
/// - There are more than `max_slots` patches.
/// - `positions.len()` differs from the patch count.
pub fn pad_patches_to_max(
    mut patches: Vec<f32>,
    mut positions: Vec<(i32, i32)>,
    model_dim: usize,
    max_slots: usize,
) -> (Vec<f32>, Vec<(i32, i32)>) {
    assert!(model_dim > 0, "model_dim must be non-zero");
    assert_eq!(
        patches.len() % model_dim,
        0,
        "patch buffer length {} is not a multiple of model_dim {model_dim}",
        patches.len()
    );
    let n = patches.len() / model_dim;
    assert!(n <= max_slots, "{n} patches do not fit in {max_slots} slots");
    assert_eq!(positions.len(), n, "positions/patches count mismatch");

    patches.resize(max_slots * model_dim, 0.0);
    positions.resize(max_slots, (-1, -1));
    (patches, positions)
}
275
276pub fn load_unified_image(
278 path: impl AsRef<std::path::Path>,
279 patch_size: usize,
280 pooling_kernel_size: usize,
281 max_soft_tokens: usize,
282) -> Result<UnifiedImageBatch> {
283 let img = image::open(path.as_ref())
284 .map_err(|e| anyhow::anyhow!("decode {:?}: {e}", path.as_ref()))?;
285 let rgb = img.to_rgb8();
286 let (w, h) = (rgb.width() as usize, rgb.height() as usize);
287 let max_patches = max_soft_tokens * pooling_kernel_size * pooling_kernel_size;
288 let (th, tw) =
289 aspect_ratio_preserving_size(h, w, patch_size, max_patches, pooling_kernel_size)?;
290 let resized = if (tw, th) != (w, h) {
291 image::DynamicImage::ImageRgb8(rgb)
292 .resize_exact(tw as u32, th as u32, image::imageops::FilterType::Triangle)
293 .to_rgb8()
294 } else {
295 rgb
296 };
297 let (teacher, tpos) = teacher_patches_from_rgb(
298 resized.as_raw(),
299 resized.width() as usize,
300 resized.height() as usize,
301 patch_size,
302 )?;
303 let teacher_dim = patch_size * patch_size * 3;
304 let num_model = teacher.len() / teacher_dim / (pooling_kernel_size * pooling_kernel_size);
305 let (merged, mpos) = patches_merge(&teacher, &tpos, num_model, teacher_dim)?;
306 let model_dim = (patch_size * pooling_kernel_size).pow(2) * 3;
307 let num_valid = num_model;
308 let (patches, positions) = pad_patches_to_max(merged, mpos, model_dim, max_soft_tokens);
309 Ok(UnifiedImageBatch {
310 patches,
311 positions,
312 num_valid,
313 })
314}
315
/// Builds a per-position additive bias from a factorized 2-D positional table.
///
/// `pos_embedding` is read as rows of `dim` floats. Row `2 * x` is the x-axis
/// embedding for coordinate `x`, and row `2 * y + 1` is the y-axis embedding for
/// coordinate `y`; that is, a row-major `[posemb_size][2][dim]` table. Each
/// output row is the sum of those two rows. A position with a negative
/// coordinate, or a coordinate `>= posemb_size`, yields an all-zero row.
///
/// # Panics
/// Panics if `pos_embedding` is too short for a coordinate that is looked up.
pub fn factorized_pos_bias(
    pos_embedding: &[f32],
    posemb_size: usize,
    dim: usize,
    positions: &[(i32, i32)],
) -> Vec<f32> {
    let mut bias = vec![0f32; positions.len() * dim];
    if dim == 0 {
        return bias;
    }
    let in_table = |coord: i32| usize::try_from(coord).ok().filter(|&c| c < posemb_size);
    for (&(x, y), row) in positions.iter().zip(bias.chunks_exact_mut(dim)) {
        let (Some(cx), Some(cy)) = (in_table(x), in_table(y)) else {
            continue;
        };
        let x_emb = &pos_embedding[2 * cx * dim..][..dim];
        let y_emb = &pos_embedding[(2 * cy + 1) * dim..][..dim];
        for ((slot, a), b) in row.iter_mut().zip(x_emb).zip(y_emb) {
            *slot = a + b;
        }
    }
    bias
}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `cols × rows` raster-ordered grid of `ps × ps` teacher patches.
    /// Channel 0 of each patch's first pixel is tagged with `index + 1`, so a test
    /// can tell where each teacher patch lands after merging.
    fn tagged_grid(cols: usize, rows: usize, ps: usize) -> (Vec<f32>, Vec<(i32, i32)>) {
        let td = ps * ps * 3;
        let mut patches = vec![0f32; cols * rows * td];
        let mut pos = Vec::with_capacity(cols * rows);
        for r in 0..rows {
            for c in 0..cols {
                let i = r * cols + c;
                pos.push((c as i32, r as i32));
                patches[i * td] = (i + 1) as f32;
            }
        }
        (patches, pos)
    }

    #[test]
    fn patches_merge_square_grid() {
        let k = 3;
        let ps = 16;
        let td = ps * ps * 3;
        let (patches, pos) = tagged_grid(6, 6, ps);
        let num_model = pos.len() / (k * k);
        let (merged, mpos) = patches_merge(&patches, &pos, num_model, td).unwrap();

        let model_dim = (k * ps).pow(2) * 3;
        assert_eq!(merged.len(), num_model * model_dim);
        // Pooled kernels come out in raster order of the 2×2 pooled grid.
        assert_eq!(mpos, vec![(0, 0), (1, 0), (0, 1), (1, 1)]);

        // Floats per pixel row of a merged patch.
        let row = k * ps * 3;
        assert_eq!(merged[0], 1.0); // teacher (0, 0) -> top-left of kernel 0
        assert_eq!(merged[ps * 3], 2.0); // teacher (1, 0) -> one patch to the right
        assert_eq!(merged[ps * row], 7.0); // teacher (0, 1) -> one patch down
        assert_eq!(merged[model_dim], 4.0); // teacher (3, 0) starts kernel 1
        assert_eq!(merged[3 * model_dim], 22.0); // teacher (3, 3) starts kernel 3
    }

    #[test]
    fn patches_merge_rejects_incompatible_counts() {
        let ps = 2;
        let td = ps * ps * 3;
        let (patches, pos) = tagged_grid(6, 6, ps);
        // 36 teacher patches are not divisible by 5.
        assert!(patches_merge(&patches, &pos, 5, td).is_err());
        // 36 / 12 = 3 is not a perfect square.
        assert!(patches_merge(&patches, &pos, 12, td).is_err());
    }

    #[test]
    fn pad_patches_to_max_appends_sentinel_slots() {
        let (p, pos) = pad_patches_to_max(vec![1.0, 2.0], vec![(0, 0)], 2, 3);
        assert_eq!(p, vec![1.0, 2.0, 0.0, 0.0, 0.0, 0.0]);
        assert_eq!(pos, vec![(0, 0), (-1, -1), (-1, -1)]);
    }

    #[test]
    fn strip_valid_vision_rows_drops_padding() {
        let projected = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let out = strip_valid_vision_rows(&projected, &[(0, 0), (-1, -1), (1, 0)], 2);
        assert_eq!(out, vec![1.0, 2.0, 5.0, 6.0]);
    }

    #[test]
    fn soft_token_count_for_square_image() {
        // 224² scaled toward 280·9 patches of 16² px, snapped to multiples of 48 px -> 768².
        assert_eq!(compute_num_soft_tokens_from_size(224, 224, 16, 3, 280).unwrap(), 256);
        // A zero patch budget cannot produce a non-empty target.
        assert!(aspect_ratio_preserving_size(10, 10, 16, 0, 3).is_err());
    }
}