// rlx_locateanything/preprocess.rs
use crate::config::{LocateAnythingConfig, LocateAnythingPreprocessorConfig, MoonVitConfig};
use anyhow::{Context, Result, ensure};
use image::DynamicImage;
use image::imageops::FilterType;
use std::path::Path;
23
/// Patch tensor cut from a single image, ready to feed the vision encoder.
///
/// As built by this module's preprocessing, `patches` stores `grid_h * grid_w` patches back to
/// back in row-major grid order, each patch occupying `patch_dim` consecutive floats.
#[derive(Debug, Clone)]
pub struct PreprocessedImage {
    /// Flattened patch values (`num_patches() * patch_dim` floats when built by this module).
    pub patches: Vec<f32>,
    /// Number of patch rows in the grid.
    pub grid_h: usize,
    /// Number of patch columns in the grid.
    pub grid_w: usize,
    /// Number of floats making up one patch.
    pub patch_dim: usize,
    /// Width in pixels of the (possibly resized) image the patches were taken from.
    pub pixel_w: u32,
    /// Height in pixels of the (possibly resized) image the patches were taken from.
    pub pixel_h: u32,
}

impl PreprocessedImage {
    /// Returns how many patches the grid contains, i.e. `grid_h * grid_w`.
    pub fn num_patches(&self) -> usize {
        let Self { grid_h, grid_w, .. } = *self;
        grid_h * grid_w
    }
}
41
42pub fn preprocess_image(
43 img: &DynamicImage,
44 cfg: &LocateAnythingConfig,
45) -> Result<PreprocessedImage> {
46 preprocess_image_with_limit(img, &cfg.vision_config, &cfg.preprocessor)
47}
48
49pub fn preprocess_path(path: &Path, cfg: &LocateAnythingConfig) -> Result<PreprocessedImage> {
50 let img = image::open(path)?;
51 preprocess_image(&img, cfg)
52}
53
54fn preprocess_image_with_limit(
55 img: &DynamicImage,
56 vit: &MoonVitConfig,
57 pre: &LocateAnythingPreprocessorConfig,
58) -> Result<PreprocessedImage> {
59 let patch_size = vit.patch_size;
60 let in_token_limit = pre.in_token_limit;
61 let merge_kernel = vit.merge_kernel_size;
62 let mean = pre.image_mean;
63 let std = pre.image_std;
64
65 let mut rgb = img.to_rgb8();
66 let (mut w, mut h) = rgb.dimensions();
67
68 let patches_before_merge = (w as usize / patch_size) * (h as usize / patch_size);
69 if patches_before_merge > in_token_limit {
70 let scale = (in_token_limit as f32 / patches_before_merge as f32).sqrt();
71 let new_w = (w as f32 * scale) as u32;
72 let new_h = (h as f32 * scale) as u32;
73 rgb = image::DynamicImage::ImageRgb8(rgb)
74 .resize_exact(new_w.max(1), new_h.max(1), FilterType::CatmullRom)
75 .to_rgb8();
76 w = rgb.width();
77 h = rgb.height();
78 }
79
80 let pad_h = merge_kernel[0] * patch_size;
81 let pad_w = merge_kernel[1] * patch_size;
82 let target_w = (w as usize).div_ceil(pad_w) * pad_w;
83 let target_h = (h as usize).div_ceil(pad_h) * pad_h;
84
85 if target_w != w as usize || target_h != h as usize {
86 rgb = image::DynamicImage::ImageRgb8(rgb)
87 .resize_exact(target_w as u32, target_h as u32, FilterType::CatmullRom)
88 .to_rgb8();
89 w = rgb.width();
90 h = rgb.height();
91 }
92
93 let grid_h = h as usize / patch_size;
94 let grid_w = w as usize / patch_size;
95 ensure!(
96 grid_h < 512 && grid_w < 512,
97 "grid {grid_h}x{grid_w} exceeds position embedding limit"
98 );
99 let mut tensor = vec![0f32; 3 * h as usize * w as usize];
100 for y in 0..h as usize {
101 for x in 0..w as usize {
102 let p = rgb.get_pixel(x as u32, y as u32);
103 for c in 0..3 {
104 let v = p[c] as f32 / 255.0;
105 tensor[c * h as usize * w as usize + y * w as usize + x] = (v - mean[c]) / std[c];
106 }
107 }
108 }
109
110 let patch_dim = 3 * patch_size * patch_size;
111 let num_patches = grid_h * grid_w;
112 let mut patches = vec![0f32; num_patches * patch_dim];
113
114 for py in 0..grid_h {
115 for px in 0..grid_w {
116 let out_patch = (py * grid_w + px) * patch_dim;
117 for c in 0..3 {
118 for dy in 0..patch_size {
119 for dx in 0..patch_size {
120 let y = py * patch_size + dy;
121 let x = px * patch_size + dx;
122 let src = c * h as usize * w as usize + y * w as usize + x;
123 let dst = out_patch + c * patch_size * patch_size + dy * patch_size + dx;
124 patches[dst] = tensor[src];
125 }
126 }
127 }
128 }
129 }
130
131 Ok(PreprocessedImage {
132 patches,
133 grid_h,
134 grid_w,
135 patch_dim,
136 pixel_w: w,
137 pixel_h: h,
138 })
139}