Skip to main content

rlx_locateanything/
preprocess.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Image preprocessing — HF `LocateAnythingImageProcessor` (native resolution).
17
18use crate::config::{LocateAnythingConfig, LocateAnythingPreprocessorConfig, MoonVitConfig};
19use anyhow::{Result, ensure};
20use image::DynamicImage;
21use image::imageops::FilterType;
22use std::path::Path;
23
/// Flattened patch tensor `[num_patches, C * patch_h * patch_w]` and grid `(h_patches, w_patches)`.
///
/// `PartialEq` is derived so that preprocessing results can be compared directly
/// (for example in regression tests).
#[derive(Debug, Clone, PartialEq)]
pub struct PreprocessedImage {
    /// Normalized pixel values, one patch after another; `num_patches() * patch_dim` entries.
    pub patches: Vec<f32>,
    /// Number of patch rows in the grid.
    pub grid_h: usize,
    /// Number of patch columns in the grid.
    pub grid_w: usize,
    /// Number of values stored per patch.
    pub patch_dim: usize,
    /// Width in pixels of the image the patches were cut from, i.e. after all resizing
    /// done during preprocessing (used for box coordinate scaling).
    pub pixel_w: u32,
    /// Height in pixels of the image the patches were cut from, i.e. after all resizing
    /// done during preprocessing (used for box coordinate scaling).
    pub pixel_h: u32,
}

impl PreprocessedImage {
    /// Total number of patches in the grid (`grid_h * grid_w`).
    pub fn num_patches(&self) -> usize {
        self.grid_h * self.grid_w
    }
}
41
42pub fn preprocess_image(
43    img: &DynamicImage,
44    cfg: &LocateAnythingConfig,
45) -> Result<PreprocessedImage> {
46    preprocess_image_with_limit(img, &cfg.vision_config, &cfg.preprocessor)
47}
48
49pub fn preprocess_path(path: &Path, cfg: &LocateAnythingConfig) -> Result<PreprocessedImage> {
50    let img = image::open(path)?;
51    preprocess_image(&img, cfg)
52}
53
54fn preprocess_image_with_limit(
55    img: &DynamicImage,
56    vit: &MoonVitConfig,
57    pre: &LocateAnythingPreprocessorConfig,
58) -> Result<PreprocessedImage> {
59    let patch_size = vit.patch_size;
60    let in_token_limit = pre.in_token_limit;
61    let merge_kernel = vit.merge_kernel_size;
62    let mean = pre.image_mean;
63    let std = pre.image_std;
64
65    let mut rgb = img.to_rgb8();
66    let (mut w, mut h) = rgb.dimensions();
67
68    let patches_before_merge = (w as usize / patch_size) * (h as usize / patch_size);
69    if patches_before_merge > in_token_limit {
70        let scale = (in_token_limit as f32 / patches_before_merge as f32).sqrt();
71        let new_w = (w as f32 * scale) as u32;
72        let new_h = (h as f32 * scale) as u32;
73        rgb = image::DynamicImage::ImageRgb8(rgb)
74            .resize_exact(new_w.max(1), new_h.max(1), FilterType::CatmullRom)
75            .to_rgb8();
76        w = rgb.width();
77        h = rgb.height();
78    }
79
80    let pad_h = merge_kernel[0] * patch_size;
81    let pad_w = merge_kernel[1] * patch_size;
82    let target_w = (w as usize).div_ceil(pad_w) * pad_w;
83    let target_h = (h as usize).div_ceil(pad_h) * pad_h;
84
85    if target_w != w as usize || target_h != h as usize {
86        rgb = image::DynamicImage::ImageRgb8(rgb)
87            .resize_exact(target_w as u32, target_h as u32, FilterType::CatmullRom)
88            .to_rgb8();
89        w = rgb.width();
90        h = rgb.height();
91    }
92
93    let grid_h = h as usize / patch_size;
94    let grid_w = w as usize / patch_size;
95    ensure!(
96        grid_h < 512 && grid_w < 512,
97        "grid {grid_h}x{grid_w} exceeds position embedding limit"
98    );
99    let mut tensor = vec![0f32; 3 * h as usize * w as usize];
100    for y in 0..h as usize {
101        for x in 0..w as usize {
102            let p = rgb.get_pixel(x as u32, y as u32);
103            for c in 0..3 {
104                let v = p[c] as f32 / 255.0;
105                tensor[c * h as usize * w as usize + y * w as usize + x] = (v - mean[c]) / std[c];
106            }
107        }
108    }
109
110    let patch_dim = 3 * patch_size * patch_size;
111    let num_patches = grid_h * grid_w;
112    let mut patches = vec![0f32; num_patches * patch_dim];
113
114    for py in 0..grid_h {
115        for px in 0..grid_w {
116            let out_patch = (py * grid_w + px) * patch_dim;
117            for c in 0..3 {
118                for dy in 0..patch_size {
119                    for dx in 0..patch_size {
120                        let y = py * patch_size + dy;
121                        let x = px * patch_size + dx;
122                        let src = c * h as usize * w as usize + y * w as usize + x;
123                        let dst = out_patch + c * patch_size * patch_size + dy * patch_size + dx;
124                        patches[dst] = tensor[src];
125                    }
126                }
127            }
128        }
129    }
130
131    Ok(PreprocessedImage {
132        patches,
133        grid_h,
134        grid_w,
135        patch_dim,
136        pixel_w: w,
137        pixel_h: h,
138    })
139}