1use super::config::{
35 SAM2_IMG_SIZE, SAM2_PATCH_GRID, SAM2_PATCH_KERNEL, SAM2_PATCH_PADDING, SAM2_PATCH_STRIDE,
36 SAM2_PIXEL_MEAN, SAM2_PIXEL_STD, Sam2HieraConfig,
37};
38use anyhow::{Result, ensure};
39use rlx_core::weight_map::WeightMap;
40
/// Host-side tensors needed to turn a preprocessed image into patch tokens.
///
/// Produced by `extract_preprocess_weights` and consumed by
/// `assemble_patch_tokens`.
#[derive(Debug, Clone)]
pub struct Sam2PreprocessWeights {
    /// Patch-embedding convolution weight, flattened row-major from
    /// `[embed_dim, 3, k, k]` where `k` is `SAM2_PATCH_KERNEL`.
    pub patch_proj_w: Vec<f32>,
    /// Patch-embedding convolution bias; `assemble_patch_tokens` copies it
    /// into every token, so it must hold `embed_dim` values.
    pub patch_proj_b: Vec<f32>,
    /// Precomputed positional embedding: `grid * grid * embed_dim` values in
    /// channel-last order, indexed as `(y * grid + x) * embed_dim + c`.
    pub pos_embed_full: Vec<f32>,
    /// Number of channels in each patch token.
    pub embed_dim: usize,
    /// Number of patches along each side of the square token grid.
    pub grid: usize,
}
57
58pub(super) fn extract_preprocess_weights(
59 weights: &mut WeightMap,
60 cfg: &Sam2HieraConfig,
61) -> Result<Sam2PreprocessWeights> {
62 let e = cfg.embed_dim;
63 let k = SAM2_PATCH_KERNEL;
64 let grid = SAM2_PATCH_GRID;
65
66 let (patch_proj_w, w_shape) = weights.take("image_encoder.trunk.patch_embed.proj.weight")?;
68 ensure!(
69 w_shape == vec![e, 3, k, k],
70 "patch_embed.proj.weight expected [{e}, 3, {k}, {k}], got {w_shape:?}"
71 );
72 let (patch_proj_b, _) = weights.take("image_encoder.trunk.patch_embed.proj.bias")?;
73
74 let (pe_raw, pe_shape) = weights.take("image_encoder.trunk.pos_embed")?;
77 let [ph, pw] = cfg.window_pos_embed_bkg_spatial_size;
78 ensure!(
79 pe_shape == vec![1, e, ph, pw],
80 "pos_embed expected [1, {e}, {ph}, {pw}], got {pe_shape:?}"
81 );
82
83 let mu = cfg.window_size_at_stage(0);
84 let (pew_raw, pew_shape) = weights.take("image_encoder.trunk.pos_embed_window")?;
85 ensure!(
86 pew_shape == vec![1, e, mu, mu],
87 "pos_embed_window expected [1, {e}, {mu}, {mu}], got {pew_shape:?}"
88 );
89
90 let pos_embed_full = build_full_pos_embed(&pe_raw, &pew_raw, e, ph, pw, mu, grid);
91
92 Ok(Sam2PreprocessWeights {
93 patch_proj_w,
94 patch_proj_b,
95 pos_embed_full,
96 embed_dim: e,
97 grid,
98 })
99}
100
101fn build_full_pos_embed(
106 pe: &[f32],
107 pew: &[f32],
108 e: usize,
109 ph: usize,
110 pw: usize,
111 mu: usize,
112 grid: usize,
113) -> Vec<f32> {
114 debug_assert_eq!(pe.len(), e * ph * pw);
115 debug_assert_eq!(pew.len(), e * mu * mu);
116 debug_assert_eq!(
117 grid % mu,
118 0,
119 "Hiera pos_embed_window must tile grid evenly (grid={grid}, mu={mu})"
120 );
121
122 let mut interp_pe = vec![0f32; e * grid * grid];
124 for c in 0..e {
125 let src = &pe[c * ph * pw..(c + 1) * ph * pw];
126 let dst = &mut interp_pe[c * grid * grid..(c + 1) * grid * grid];
127 bicubic_resize_2d(src, ph, pw, dst, grid, grid);
128 }
129
130 let mut out_nchw = interp_pe; for c in 0..e {
134 for y in 0..grid {
135 let ty = y % mu;
136 for x in 0..grid {
137 let tx = x % mu;
138 let w_val = pew[c * mu * mu + ty * mu + tx];
139 out_nchw[c * grid * grid + y * grid + x] += w_val;
140 }
141 }
142 }
143
144 let mut out_bhwc = vec![0f32; grid * grid * e];
146 for y in 0..grid {
147 for x in 0..grid {
148 for c in 0..e {
149 out_bhwc[(y * grid + x) * e + c] = out_nchw[c * grid * grid + y * grid + x];
150 }
151 }
152 }
153 out_bhwc
154}
155
/// Resizes one row-major `h_in x w_in` plane into `dst` (`h_out x w_out`)
/// using separable bicubic interpolation.
///
/// Output pixel centres are mapped to source space with the half-pixel
/// convention `(o + 0.5) * scale - 0.5`; the four taps around that position
/// are weighted with the cubic convolution kernel (coefficient `-0.75`) and
/// taps that fall outside the source are clamped to the nearest edge pixel.
fn bicubic_resize_2d(
    src: &[f32],
    h_in: usize,
    w_in: usize,
    dst: &mut [f32],
    h_out: usize,
    w_out: usize,
) {
    const A: f32 = -0.75;

    // Cubic convolution kernel evaluated at signed distance `dist`.
    let kernel = |dist: f32| -> f32 {
        let d = dist.abs();
        if d < 1.0 {
            ((A + 2.0) * d - (A + 3.0)) * d * d + 1.0
        } else if d < 2.0 {
            (((d - 5.0) * d + 8.0) * d - 4.0) * A
        } else {
            0.0
        }
    };
    // Weights for the taps at offsets -1, 0, +1, +2 from the base sample.
    let tap_weights = |frac: f32| -> [f32; 4] {
        [
            kernel(1.0 + frac),
            kernel(frac),
            kernel(1.0 - frac),
            kernel(2.0 - frac),
        ]
    };
    // Source position of an output index, split into base sample and fraction.
    let locate = |out_idx: usize, scale: f32| -> (isize, f32) {
        let pos = (out_idx as f32 + 0.5) * scale - 0.5;
        let base = pos.floor();
        (base as isize, pos - base)
    };
    // Edge replication for taps outside the source plane.
    let edge_clamp = |i: isize, len: usize| -> usize { i.clamp(0, len as isize - 1) as usize };

    let scale_x = (w_in as f32) / (w_out as f32);
    let scale_y = (h_in as f32) / (h_out as f32);

    for row in 0..h_out {
        let (row_base, row_frac) = locate(row, scale_y);
        let row_weights = tap_weights(row_frac);
        for col in 0..w_out {
            let (col_base, col_frac) = locate(col, scale_x);
            let col_weights = tap_weights(col_frac);

            let mut sum = 0f32;
            for (ty, &wy) in row_weights.iter().enumerate() {
                let src_y = edge_clamp(row_base - 1 + ty as isize, h_in);
                for (tx, &wx) in col_weights.iter().enumerate() {
                    let src_x = edge_clamp(col_base - 1 + tx as isize, w_in);
                    sum += src[src_y * w_in + src_x] * wy * wx;
                }
            }
            dst[row * w_out + col] = sum;
        }
    }
}
214
215pub fn preprocess_image(rgb: &[u8], h_in: usize, w_in: usize) -> Vec<f32> {
223 debug_assert_eq!(rgb.len(), h_in * w_in * 3);
224 let out_size = SAM2_IMG_SIZE;
225 let mut nchw = vec![0f32; 3 * out_size * out_size];
226
227 let sx = (w_in as f32) / (out_size as f32);
229 let sy = (h_in as f32) / (out_size as f32);
230
231 for y_o in 0..out_size {
232 let yf = (y_o as f32 + 0.5) * sy - 0.5;
233 let y0 = yf.floor().max(0.0) as usize;
234 let y1 = (y0 + 1).min(h_in - 1);
235 let dy = (yf - yf.floor()).clamp(0.0, 1.0);
236 for x_o in 0..out_size {
237 let xf = (x_o as f32 + 0.5) * sx - 0.5;
238 let x0 = xf.floor().max(0.0) as usize;
239 let x1 = (x0 + 1).min(w_in - 1);
240 let dx = (xf - xf.floor()).clamp(0.0, 1.0);
241 for c in 0..3 {
242 let p00 = rgb[(y0 * w_in + x0) * 3 + c] as f32;
243 let p01 = rgb[(y0 * w_in + x1) * 3 + c] as f32;
244 let p10 = rgb[(y1 * w_in + x0) * 3 + c] as f32;
245 let p11 = rgb[(y1 * w_in + x1) * 3 + c] as f32;
246 let top = p00 * (1.0 - dx) + p01 * dx;
247 let bot = p10 * (1.0 - dx) + p11 * dx;
248 let v01 = (top * (1.0 - dy) + bot * dy) / 255.0;
249 nchw[c * out_size * out_size + y_o * out_size + x_o] =
250 (v01 - SAM2_PIXEL_MEAN[c]) / SAM2_PIXEL_STD[c];
251 }
252 }
253 }
254 nchw
255}
256
257pub fn assemble_patch_tokens(pre: &Sam2PreprocessWeights, image_nchw: &[f32]) -> Result<Vec<f32>> {
263 let e = pre.embed_dim;
264 let grid = pre.grid;
265 let k = SAM2_PATCH_KERNEL;
266 let s = SAM2_PATCH_STRIDE;
267 let pad = SAM2_PATCH_PADDING;
268 ensure!(
269 image_nchw.len() == 3 * SAM2_IMG_SIZE * SAM2_IMG_SIZE,
270 "image must be [3, {}, {}] NCHW, got len {}",
271 SAM2_IMG_SIZE,
272 SAM2_IMG_SIZE,
273 image_nchw.len()
274 );
275
276 let h = SAM2_IMG_SIZE;
277 let w = SAM2_IMG_SIZE;
278 let mut out = vec![0f32; grid * grid * e];
279
280 for py in 0..grid {
284 for px in 0..grid {
285 let dst = &mut out[(py * grid + px) * e..(py * grid + px + 1) * e];
287 dst.copy_from_slice(&pre.patch_proj_b);
289 for ky in 0..k {
291 let iy = (py * s) as isize + ky as isize - pad as isize;
292 if iy < 0 || iy >= h as isize {
293 continue;
294 }
295 let iy = iy as usize;
296 for kx in 0..k {
297 let ix = (px * s) as isize + kx as isize - pad as isize;
298 if ix < 0 || ix >= w as isize {
299 continue;
300 }
301 let ix = ix as usize;
302 for c in 0..3 {
303 let v = image_nchw[c * h * w + iy * w + ix];
304 let w_base = c * k * k + ky * k + kx;
306 let stride = 3 * k * k;
307 for ei in 0..e {
308 dst[ei] += v * pre.patch_proj_w[ei * stride + w_base];
309 }
310 }
311 }
312 }
313 }
314 }
315
316 ensure!(
318 pre.pos_embed_full.len() == grid * grid * e,
319 "pos_embed_full size mismatch"
320 );
321 for i in 0..grid * grid * e {
322 out[i] += pre.pos_embed_full[i];
323 }
324 Ok(out)
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// A uniform gray image must produce the expected output length and the
    /// normalized gray value at the centre of every channel plane.
    #[test]
    fn preprocess_shape_and_range() {
        const SIDE: usize = 1024;
        let img = vec![128u8; 50 * 30 * 3];
        let nchw = preprocess_image(&img, 50, 30);
        assert_eq!(nchw.len(), 3 * SIDE * SIDE);

        let centre = 512 * SIDE + 512;
        for (c, plane) in nchw.chunks_exact(SIDE * SIDE).enumerate() {
            let expected = (128.0 / 255.0 - SAM2_PIXEL_MEAN[c]) / SAM2_PIXEL_STD[c];
            let mid = plane[centre];
            assert!(
                (mid - expected).abs() < 1e-4,
                "channel {c}: {mid} vs {expected}"
            );
        }
    }

    /// Resizing a plane to its own size must reproduce the input.
    #[test]
    fn bicubic_identity() {
        let src: Vec<f32> = (0..64u8).map(f32::from).collect();
        let mut dst = vec![0f32; src.len()];
        bicubic_resize_2d(&src, 8, 8, &mut dst, 8, 8);
        for (i, (want, got)) in src.iter().zip(&dst).enumerate() {
            assert!(
                (want - got).abs() < 1e-4,
                "identity broken at {i}: {} vs {}",
                want,
                got
            );
        }
    }
}