Skip to main content

rlx_gemma/
unified_preprocess.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! CPU preprocessing for Gemma 4 **unified** (12B) image + audio inputs.
17//!
18//! Matches the HuggingFace `Gemma4UnifiedImageProcessor` pipeline:
19//! aspect-ratio resize → `[0,1]` rescale → 16px teacher patchify →
20//! 3×3 patch merge → pad to `max_soft_tokens`.
21
22use anyhow::{Result, bail};
23
/// Flattened `f32` patch data paired with the `(x, y)` grid coordinate of each patch.
type PatchGridResult = (Vec<f32>, PatchGrid);
/// One `(x, y)` grid coordinate per patch.
type PatchGrid = Vec<(i32, i32)>;
26
/// `max_soft_tokens` budgets accepted by the HF Gemma 4 unified processor.
pub const SUPPORTED_MAX_SOFT_TOKENS: [usize; 5] = [70, 140, 280, 560, 1120];

/// Upper bound on raw audio samples: 30 seconds at 16 kHz.
pub const MAX_AUDIO_SAMPLES: usize = 30 * 16_000;

/// Audio frame counts are rounded up to a multiple of this value
/// (HF batches pad to 128).
pub const AUDIO_FRAME_PAD_MULTIPLE: usize = 128;
35
36/// Compute dynamic soft-token count for an image size (before padding).
37pub fn compute_num_soft_tokens_from_size(
38    height: usize,
39    width: usize,
40    patch_size: usize,
41    pooling_kernel_size: usize,
42    max_soft_tokens: usize,
43) -> Result<usize> {
44    let max_patches = max_soft_tokens * pooling_kernel_size * pooling_kernel_size;
45    let (th, tw) =
46        aspect_ratio_preserving_size(height, width, patch_size, max_patches, pooling_kernel_size)?;
47    let teacher = (th / patch_size) * (tw / patch_size);
48    Ok(teacher / (pooling_kernel_size * pooling_kernel_size))
49}
50
/// Drop padding rows from a projected `[slots × hidden]` buffer.
///
/// Row `i` is kept when `positions[i]` has both coordinates non-negative. Rows with
/// no matching entry in `positions` count as padding. A trailing partial row (when
/// `projected.len()` is not a multiple of `hidden`) is ignored. With `hidden == 0`
/// the result is empty.
pub fn strip_valid_vision_rows(
    projected: &[f32],
    positions: &[(i32, i32)],
    hidden: usize,
) -> Vec<f32> {
    if hidden == 0 {
        return Vec::new();
    }
    // Missing positions behave like explicit `(-1, -1)` padding markers.
    let padded_positions = positions.iter().copied().chain(std::iter::repeat((-1, -1)));
    projected
        .chunks_exact(hidden)
        .zip(padded_positions)
        .filter(|&(_, (x, y))| x >= 0 && y >= 0)
        .flat_map(|(row, _)| row.iter().copied())
        .collect()
}
67
68/// Truncate + frame-count for unified 12B raw PCM (640 samples/token).
69pub fn unified_audio_token_count(
70    num_samples: usize,
71    samples_per_token: usize,
72    max_tokens: usize,
73) -> usize {
74    let capped = num_samples.min(MAX_AUDIO_SAMPLES);
75    capped.div_ceil(samples_per_token).max(1).min(max_tokens)
76}
77
78pub fn prepare_unified_audio_samples(
79    samples: &[f32],
80    samples_per_token: usize,
81    max_tokens: usize,
82) -> Vec<f32> {
83    let capped_len = samples.len().min(MAX_AUDIO_SAMPLES);
84    let mut truncated = samples[..capped_len].to_vec();
85    let mut num_frames = truncated.len().div_ceil(samples_per_token).max(1);
86    num_frames = num_frames.min(max_tokens);
87    let padded_frames = num_frames.div_ceil(AUDIO_FRAME_PAD_MULTIPLE) * AUDIO_FRAME_PAD_MULTIPLE;
88    truncated.resize(padded_frames * samples_per_token, 0.0);
89    truncated
90}
91
/// Output of [`load_unified_image`]: merged model patches padded to
/// `max_soft_tokens` slots, together with per-slot grid positions.
#[derive(Clone, Debug)]
pub struct UnifiedImageBatch {
    /// Merged RGB patches, row-major `[num_slots, model_dim]` (48×48×3 = 6912
    /// floats per slot for 16px teacher patches merged 3×3).
    pub patches: Vec<f32>,
    /// `(x, y)` grid coordinate of each slot; padding slots hold `(-1, -1)`.
    pub positions: Vec<(i32, i32)>,
    /// Number of leading slots holding real patches (the rest are padding).
    pub num_valid: usize,
}
101
102/// Compute target `(height, width)` preserving aspect ratio within the
103/// teacher-patch budget. Port of HF `get_aspect_ratio_preserving_size`.
104pub fn aspect_ratio_preserving_size(
105    height: usize,
106    width: usize,
107    patch_size: usize,
108    max_patches: usize,
109    pooling_kernel_size: usize,
110) -> Result<(usize, usize)> {
111    let total_px = height * width;
112    let target_px = max_patches * patch_size * patch_size;
113    let factor = (target_px as f64 / total_px as f64).sqrt();
114    let ideal_height = factor * height as f64;
115    let ideal_width = factor * width as f64;
116    let side_mult = pooling_kernel_size * patch_size;
117
118    let mut target_height = (ideal_height / side_mult as f64).floor() as usize * side_mult;
119    let mut target_width = (ideal_width / side_mult as f64).floor() as usize * side_mult;
120
121    if target_height == 0 && target_width == 0 {
122        bail!(
123            "resize target is 0×0; image too small for patch_size={patch_size} \
124             pooling_kernel_size={pooling_kernel_size}"
125        );
126    }
127
128    let max_side_length = (max_patches / (pooling_kernel_size * pooling_kernel_size)) * side_mult;
129    if target_height == 0 {
130        target_height = side_mult;
131        target_width =
132            ((width as f64 / height as f64).floor() as usize * side_mult).min(max_side_length);
133    }
134    if target_width == 0 {
135        target_width = side_mult;
136        target_height =
137            ((height as f64 / width as f64).floor() as usize * side_mult).min(max_side_length);
138    }
139    Ok((target_height.max(side_mult), target_width.max(side_mult)))
140}
141
142/// Teacher-level 16×16 patchify from interleaved RGB u8, values in `[0, 1]`.
143pub fn teacher_patches_from_rgb(
144    rgb: &[u8],
145    width: usize,
146    height: usize,
147    patch_size: usize,
148) -> Result<PatchGridResult> {
149    if rgb.len() != width * height * 3 {
150        bail!("rgb len {} != {width}×{height}×3", rgb.len());
151    }
152    let patch_cols = width / patch_size;
153    let patch_rows = height / patch_size;
154    let num = patch_rows * patch_cols;
155    let per = patch_size * patch_size * 3;
156    let inv = 1.0 / 255.0;
157    let mut patches = vec![0f32; num * per];
158    let mut positions = Vec::with_capacity(num);
159    for pr in 0..patch_rows {
160        for pc in 0..patch_cols {
161            let idx = pr * patch_cols + pc;
162            positions.push((pc as i32, pr as i32));
163            let dst_base = idx * per;
164            for py in 0..patch_size {
165                for px in 0..patch_size {
166                    let src = ((pr * patch_size + py) * width + (pc * patch_size + px)) * 3;
167                    let dst = dst_base + (py * patch_size + px) * 3;
168                    patches[dst] = rgb[src] as f32 * inv;
169                    patches[dst + 1] = rgb[src + 1] as f32 * inv;
170                    patches[dst + 2] = rgb[src + 2] as f32 * inv;
171                }
172            }
173        }
174    }
175    Ok((patches, positions))
176}
177
178/// Merge `k×k` teacher patches into 48×48 model patches. Port of HF
179/// `patches_merge` (single batch, no torch).
180pub fn patches_merge(
181    patches: &[f32],
182    positions: &[(i32, i32)],
183    num_model_patches: usize,
184    teacher_patch_dim: usize,
185) -> Result<PatchGridResult> {
186    let l = patches.len() / teacher_patch_dim;
187    if l != num_model_patches {
188        let k2 = l / num_model_patches;
189        if k2 * num_model_patches != l {
190            bail!("cannot merge {l} teacher patches into {num_model_patches} model patches");
191        }
192    }
193    let k = ((l / num_model_patches) as f64).sqrt() as usize;
194    if k * k * num_model_patches != l {
195        bail!("patch count {l} is not num_model×k²");
196    }
197    let patch_size = (teacher_patch_dim / 3).isqrt();
198    let model_dim = (k * patch_size) * (k * patch_size) * 3;
199
200    // Build target ordering (argsort of kernel-group indices).
201    let max_x = positions.iter().map(|(x, _)| *x).max().unwrap_or(0).max(0) as usize + 1;
202    let mut order: Vec<usize> = (0..l).collect();
203    order.sort_by_key(|&i| {
204        let (x, y) = positions[i];
205        let kx = (x as usize) / k;
206        let ky = (y as usize) / k;
207        let num_from_tl = k * k * kx + k * max_x * ky;
208        let px = (x as usize) % k;
209        let py = (y as usize) % k;
210        num_from_tl + px + py * k
211    });
212
213    let mut kernel_ordered: Vec<f32> = vec![0.0; l * teacher_patch_dim];
214    let mut kernel_pos: Vec<(i32, i32)> = vec![(0, 0); l];
215    for (out_i, &src_i) in order.iter().enumerate() {
216        kernel_ordered[out_i * teacher_patch_dim..(out_i + 1) * teacher_patch_dim]
217            .copy_from_slice(&patches[src_i * teacher_patch_dim..(src_i + 1) * teacher_patch_dim]);
218        kernel_pos[out_i] = positions[src_i];
219    }
220
221    let mut merged = vec![0f32; num_model_patches * model_dim];
222    let mut merged_pos = vec![(-1, -1); num_model_patches];
223
224    for mp in 0..num_model_patches {
225        let base = mp * k * k;
226        let mut min_x = i32::MAX;
227        let mut min_y = i32::MAX;
228        let mut out_off = 0usize;
229        for ky in 0..k {
230            for kx in 0..k {
231                let ti = base + ky * k + kx;
232                let (x, y) = kernel_pos[ti];
233                if x >= 0 {
234                    min_x = min_x.min(x / k as i32);
235                    min_y = min_y.min(y / k as i32);
236                }
237                // Spatial merge: place k×k teacher tiles into model patch grid.
238                for py in 0..patch_size {
239                    for px in 0..patch_size {
240                        for c in 0..3 {
241                            let src = ti * teacher_patch_dim + (py * patch_size + px) * 3 + c;
242                            let dst = mp * model_dim
243                                + ((ky * patch_size + py) * (k * patch_size)
244                                    + (kx * patch_size + px))
245                                    * 3
246                                + c;
247                            merged[dst] = kernel_ordered[src];
248                        }
249                    }
250                }
251                out_off += 1;
252            }
253        }
254        let _ = out_off;
255        if min_x != i32::MAX {
256            merged_pos[mp] = (min_x, min_y);
257        }
258    }
259    Ok((merged, merged_pos))
260}
261
/// Pad merged patches and their positions out to a fixed `max_slots` rows.
///
/// `patches` holds `n = patches.len() / model_dim` rows of `model_dim` floats, and
/// `positions` the matching `n` grid coordinates. Padding rows are zero-filled and
/// their positions set to `(-1, -1)`. The input buffers are reused, so no second
/// `max_slots × model_dim` buffer is allocated.
///
/// # Panics
/// Panics if:
/// - `model_dim` is 0;
/// - `patches.len()` is not a multiple of `model_dim`;
/// - `positions.len() != n`;
/// - `n > max_slots`.
pub fn pad_patches_to_max(
    mut patches: Vec<f32>,
    mut positions: Vec<(i32, i32)>,
    model_dim: usize,
    max_slots: usize,
) -> (Vec<f32>, Vec<(i32, i32)>) {
    assert!(model_dim > 0, "pad_patches_to_max: model_dim must be > 0");
    assert_eq!(
        patches.len() % model_dim,
        0,
        "pad_patches_to_max: patch buffer is not a whole number of rows"
    );
    let n = patches.len() / model_dim;
    assert_eq!(
        positions.len(),
        n,
        "pad_patches_to_max: positions/patches row count mismatch"
    );
    assert!(
        n <= max_slots,
        "pad_patches_to_max: {n} patches exceed {max_slots} slots"
    );
    patches.resize(max_slots * model_dim, 0.0);
    positions.resize(max_slots, (-1, -1));
    (patches, positions)
}
275
276/// Full unified image pipeline from a JPEG/PNG path.
277pub fn load_unified_image(
278    path: impl AsRef<std::path::Path>,
279    patch_size: usize,
280    pooling_kernel_size: usize,
281    max_soft_tokens: usize,
282) -> Result<UnifiedImageBatch> {
283    let img = image::open(path.as_ref())
284        .map_err(|e| anyhow::anyhow!("decode {:?}: {e}", path.as_ref()))?;
285    let rgb = img.to_rgb8();
286    let (w, h) = (rgb.width() as usize, rgb.height() as usize);
287    let max_patches = max_soft_tokens * pooling_kernel_size * pooling_kernel_size;
288    let (th, tw) =
289        aspect_ratio_preserving_size(h, w, patch_size, max_patches, pooling_kernel_size)?;
290    let resized = if (tw, th) != (w, h) {
291        image::DynamicImage::ImageRgb8(rgb)
292            .resize_exact(tw as u32, th as u32, image::imageops::FilterType::Triangle)
293            .to_rgb8()
294    } else {
295        rgb
296    };
297    let (teacher, tpos) = teacher_patches_from_rgb(
298        resized.as_raw(),
299        resized.width() as usize,
300        resized.height() as usize,
301        patch_size,
302    )?;
303    let teacher_dim = patch_size * patch_size * 3;
304    let num_model = teacher.len() / teacher_dim / (pooling_kernel_size * pooling_kernel_size);
305    let (merged, mpos) = patches_merge(&teacher, &tpos, num_model, teacher_dim)?;
306    let model_dim = (patch_size * pooling_kernel_size).pow(2) * 3;
307    let num_valid = num_model;
308    let (patches, positions) = pad_patches_to_max(merged, mpos, model_dim, max_soft_tokens);
309    Ok(UnifiedImageBatch {
310        patches,
311        positions,
312        num_valid,
313    })
314}
315
/// Factorized 2D positional bias.
///
/// `pos_embedding` is read as row-major `[posemb_size, 2, dim]`: entry `[p, 0, :]`
/// is the x-axis embedding of coordinate `p`, and `[p, 1, :]` the y-axis one.
/// Output row `i` is `pos_embedding[x, 0, :] + pos_embedding[y, 1, :]` for
/// `positions[i] = (x, y)`. It stays all zeros when either coordinate is negative
/// or `>= posemb_size`. The result is `[positions.len(), dim]`, row-major.
pub fn factorized_pos_bias(
    pos_embedding: &[f32],
    posemb_size: usize,
    dim: usize,
    positions: &[(i32, i32)],
) -> Vec<f32> {
    let mut bias = vec![0f32; positions.len() * dim];
    if dim == 0 {
        return bias;
    }
    // Coordinate → table index, or None when it falls outside the table.
    let in_table = |v: i32| usize::try_from(v).ok().filter(|&u| u < posemb_size);
    for (row, &(x, y)) in bias.chunks_exact_mut(dim).zip(positions) {
        let (Some(x), Some(y)) = (in_table(x), in_table(y)) else {
            continue;
        };
        let x_emb = &pos_embedding[2 * x * dim..][..dim];
        let y_emb = &pos_embedding[(2 * y + 1) * dim..][..dim];
        for ((dst, &a), &b) in row.iter_mut().zip(x_emb).zip(y_emb) {
            *dst = a + b;
        }
    }
    bias
}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn patches_merge_square_grid() {
        let k = 3;
        let ps = 16;
        let td = ps * ps * 3;
        // 6×6 teacher grid → 2×2 grid of 3×3 kernels → 4 model patches.
        let cols = 6;
        let rows = 6;
        let l = cols * rows;
        let mut patches = vec![0f32; l * td];
        let mut pos = Vec::with_capacity(l);
        for r in 0..rows {
            for c in 0..cols {
                let i = r * cols + c;
                pos.push((c as i32, r as i32));
                // Tag each teacher patch's first value with its 1-based index.
                patches[i * td] = (i + 1) as f32;
            }
        }
        let num_model = l / (k * k);
        let (merged, mpos) = patches_merge(&patches, &pos, num_model, td).unwrap();
        let model_side = k * ps;
        let model_dim = model_side * model_side * 3;
        assert_eq!(merged.len(), num_model * model_dim);
        // Model patches follow the kernel grid row-major.
        assert_eq!(mpos, vec![(0, 0), (1, 0), (0, 1), (1, 1)]);
        for (mp, &(kx, ky)) in mpos.iter().enumerate() {
            for ty in 0..k {
                for tx in 0..k {
                    // Top-left pixel, channel 0, of tile (tx, ty) inside model patch mp.
                    let dst = mp * model_dim + (ty * ps * model_side + tx * ps) * 3;
                    let teacher = (ky as usize * k + ty) * cols + kx as usize * k + tx;
                    assert_eq!(merged[dst], (teacher + 1) as f32);
                }
            }
        }
    }

    #[test]
    fn pad_patches_fills_tail_with_padding() {
        let (p, pos) = pad_patches_to_max(vec![1.0, 2.0, 3.0, 4.0], vec![(0, 0), (1, 0)], 2, 3);
        assert_eq!(p, vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0]);
        assert_eq!(pos, vec![(0, 0), (1, 0), (-1, -1)]);
    }

    #[test]
    fn strip_keeps_only_valid_rows() {
        let projected = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let kept = strip_valid_vision_rows(&projected, &[(0, 0), (-1, -1), (1, 0)], 2);
        assert_eq!(kept, vec![1.0, 2.0, 5.0, 6.0]);
    }

    #[test]
    fn soft_token_budget_is_respected() {
        // 480×640 at 280 tokens → 672×912 → 42×57 teacher → 14×19 = 266 tokens.
        assert_eq!(compute_num_soft_tokens_from_size(480, 640, 16, 3, 280).unwrap(), 266);
        // Extremely wide image: height clamps to one merge unit, width to the budget.
        assert_eq!(
            aspect_ratio_preserving_size(1, 1000, 16, 280 * 9, 3).unwrap(),
            (48, 280 * 48)
        );
    }
}