use super::config::{
SAM2_IMG_SIZE, SAM2_PATCH_GRID, SAM2_PATCH_KERNEL, SAM2_PATCH_PADDING, SAM2_PATCH_STRIDE,
SAM2_PIXEL_MEAN, SAM2_PIXEL_STD, Sam2HieraConfig,
};
use anyhow::{Result, ensure};
use rlx_core::weight_map::WeightMap;
/// Weights needed to turn a preprocessed image into Hiera patch tokens.
///
/// Produced by `extract_preprocess_weights` and consumed by `assemble_patch_tokens`.
#[derive(Debug, Clone)]
pub struct Sam2PreprocessWeights {
    /// Patch-embedding conv weight, flattened from shape
    /// `[embed_dim, 3, SAM2_PATCH_KERNEL, SAM2_PATCH_KERNEL]`.
    pub patch_proj_w: Vec<f32>,
    /// Patch-embedding conv bias: one value per output channel (length `embed_dim`).
    pub patch_proj_b: Vec<f32>,
    /// Positional embedding, already resized and window-tiled, laid out
    /// channels-last as `[grid, grid, embed_dim]`.
    pub pos_embed_full: Vec<f32>,
    /// Channel count of each patch token.
    pub embed_dim: usize,
    /// Number of patch tokens along each side of the image.
    pub grid: usize,
}
/// Removes the Hiera trunk's patch-embedding and positional-embedding tensors
/// from `weights` and packages them as [`Sam2PreprocessWeights`].
///
/// The background positional embedding (`pos_embed`) and the stage-0 window
/// embedding (`pos_embed_window`) are combined up front into a channels-last
/// `[grid, grid, embed_dim]` table, so token assembly only needs a single add.
///
/// # Errors
/// Returns an error if any tensor is missing from `weights` (as reported by
/// `WeightMap::take`), if a tensor's shape disagrees with `cfg`, or if the
/// stage-0 window size is zero or does not evenly divide the patch grid.
pub(super) fn extract_preprocess_weights(
    weights: &mut WeightMap,
    cfg: &Sam2HieraConfig,
) -> Result<Sam2PreprocessWeights> {
    let e = cfg.embed_dim;
    let k = SAM2_PATCH_KERNEL;
    let grid = SAM2_PATCH_GRID;

    let (patch_proj_w, w_shape) = weights.take("image_encoder.trunk.patch_embed.proj.weight")?;
    ensure!(
        w_shape == vec![e, 3, k, k],
        "patch_embed.proj.weight expected [{e}, 3, {k}, {k}], got {w_shape:?}"
    );

    // The bias is copied into every token verbatim, so it must hold exactly one
    // value per output channel; otherwise the mismatch would only surface later,
    // far from the checkpoint that caused it.
    let (patch_proj_b, b_shape) = weights.take("image_encoder.trunk.patch_embed.proj.bias")?;
    ensure!(
        b_shape == vec![e],
        "patch_embed.proj.bias expected [{e}], got {b_shape:?}"
    );

    let (pe_raw, pe_shape) = weights.take("image_encoder.trunk.pos_embed")?;
    let [ph, pw] = cfg.window_pos_embed_bkg_spatial_size;
    ensure!(
        pe_shape == vec![1, e, ph, pw],
        "pos_embed expected [1, {e}, {ph}, {pw}], got {pe_shape:?}"
    );

    let mu = cfg.window_size_at_stage(0);
    // The window embedding is tiled over the grid with period `mu`; reject
    // configurations where that tiling is undefined (mu == 0) or ragged.
    ensure!(
        mu > 0 && grid % mu == 0,
        "pos_embed_window size {mu} must evenly tile the {grid}x{grid} patch grid"
    );
    let (pew_raw, pew_shape) = weights.take("image_encoder.trunk.pos_embed_window")?;
    ensure!(
        pew_shape == vec![1, e, mu, mu],
        "pos_embed_window expected [1, {e}, {mu}, {mu}], got {pew_shape:?}"
    );

    let pos_embed_full = build_full_pos_embed(&pe_raw, &pew_raw, e, ph, pw, mu, grid);
    Ok(Sam2PreprocessWeights {
        patch_proj_w,
        patch_proj_b,
        pos_embed_full,
        embed_dim: e,
        grid,
    })
}
/// Builds the full-resolution positional embedding that is added to every patch token.
///
/// `pe` is the background embedding laid out `[e, ph, pw]`, and `pew` is the window
/// embedding laid out `[e, mu, mu]`. Each `pe` channel is bicubically resized to
/// `grid x grid`. The matching `pew` channel is then tiled across it (period `mu` in
/// both axes) and added. The sum is returned channels-last as `[grid, grid, e]`.
fn build_full_pos_embed(
    pe: &[f32],
    pew: &[f32],
    e: usize,
    ph: usize,
    pw: usize,
    mu: usize,
    grid: usize,
) -> Vec<f32> {
    debug_assert_eq!(pe.len(), e * ph * pw);
    debug_assert_eq!(pew.len(), e * mu * mu);
    debug_assert_eq!(
        grid % mu,
        0,
        "Hiera pos_embed_window must tile grid evenly (grid={grid}, mu={mu})"
    );

    let src_area = ph * pw;
    let win_area = mu * mu;
    let area = grid * grid;

    // Channel-major scratch holding resized background + tiled window embedding.
    let mut planes = vec![0f32; e * area];
    for ch in 0..e {
        let plane = &mut planes[ch * area..(ch + 1) * area];
        bicubic_resize_2d(&pe[ch * src_area..(ch + 1) * src_area], ph, pw, plane, grid, grid);
        for (pos, value) in plane.iter_mut().enumerate() {
            let (row, col) = (pos / grid, pos % grid);
            *value += pew[ch * win_area + (row % mu) * mu + col % mu];
        }
    }

    // Transpose [e, grid, grid] -> [grid, grid, e] (channels-last).
    let planes = planes.as_slice();
    (0..area)
        .flat_map(|pos| (0..e).map(move |ch| planes[ch * area + pos]))
        .collect()
}
/// Resizes one row-major `h_in x w_in` plane to `h_out x w_out` with bicubic
/// interpolation, writing row-major into `dst`.
///
/// Uses the cubic-convolution kernel with `A = -0.75` and half-pixel sample
/// centres (`(o + 0.5) * scale - 0.5`). Taps outside the source are clamped to
/// the nearest edge pixel.
fn bicubic_resize_2d(
    src: &[f32],
    h_in: usize,
    w_in: usize,
    dst: &mut [f32],
    h_out: usize,
    w_out: usize,
) {
    /// Cubic-convolution weight for a tap at signed distance `t`.
    fn keys_weight(t: f32) -> f32 {
        const A: f32 = -0.75;
        let t = t.abs();
        if t < 1.0 {
            ((A + 2.0) * t - (A + 3.0)) * t * t + 1.0
        } else if t < 2.0 {
            (((t - 5.0) * t + 8.0) * t - 4.0) * A
        } else {
            0.0
        }
    }

    /// The four edge-clamped source indices and their kernel weights for
    /// output coordinate `o` along an axis with `n_in` source samples.
    fn taps(o: usize, n_in: usize, scale: f32) -> ([usize; 4], [f32; 4]) {
        let pos = (o as f32 + 0.5) * scale - 0.5;
        let base = pos.floor();
        let frac = pos - base;
        let first = base as isize - 1;
        let last = n_in as isize - 1;
        let indices = [0, 1, 2, 3].map(|j| (first + j).clamp(0, last) as usize);
        let weights = [
            keys_weight(1.0 + frac),
            keys_weight(frac),
            keys_weight(1.0 - frac),
            keys_weight(2.0 - frac),
        ];
        (indices, weights)
    }

    // Nothing is written for an empty output.
    if h_out == 0 || w_out == 0 {
        return;
    }

    let scale_x = (w_in as f32) / (w_out as f32);
    let scale_y = (h_in as f32) / (h_out as f32);
    // Horizontal taps are identical for every output row.
    let col_taps: Vec<([usize; 4], [f32; 4])> =
        (0..w_out).map(|x_o| taps(x_o, w_in, scale_x)).collect();

    for y_o in 0..h_out {
        let (rows, wy) = taps(y_o, h_in, scale_y);
        for (x_o, (cols, wx)) in col_taps.iter().enumerate() {
            let mut acc = 0f32;
            for (&iy, &wy_j) in rows.iter().zip(&wy) {
                let line = iy * w_in;
                for (&ix, &wx_j) in cols.iter().zip(wx) {
                    acc += src[line + ix] * wy_j * wx_j;
                }
            }
            dst[y_o * w_out + x_o] = acc;
        }
    }
}
/// Resizes an interleaved RGB8 image (`h_in x w_in x 3`, row-major) to
/// `SAM2_IMG_SIZE x SAM2_IMG_SIZE` with bilinear sampling and normalizes it,
/// returning a `[3, SAM2_IMG_SIZE, SAM2_IMG_SIZE]` NCHW buffer.
///
/// Sampling uses half-pixel centres (`(o + 0.5) * scale - 0.5`, i.e.
/// `align_corners=False`). Source coordinates before the first pixel are
/// clamped to it, so border rows/columns replicate the edge pixel instead of
/// blending in the next one. No anti-aliasing is applied when downsampling.
///
/// Each value is scaled to `[0, 1]` and then normalized as
/// `(v - SAM2_PIXEL_MEAN[c]) / SAM2_PIXEL_STD[c]`.
///
/// # Panics
/// Panics if `h_in` or `w_in` is zero, or if `rgb.len() != h_in * w_in * 3`.
pub fn preprocess_image(rgb: &[u8], h_in: usize, w_in: usize) -> Vec<f32> {
    assert!(
        h_in > 0 && w_in > 0,
        "preprocess_image: image must be non-empty, got {h_in}x{w_in}"
    );
    assert_eq!(
        rgb.len(),
        h_in * w_in * 3,
        "preprocess_image: rgb buffer does not match {h_in}x{w_in}x3"
    );

    /// Bilinear taps for output index `o` along one axis: the two neighbouring
    /// source indices and the blend weight toward the second. The source
    /// coordinate is clamped into range *before* the weight is derived, so the
    /// weight always matches the indices actually used.
    fn taps(o: usize, n_in: usize, n_out: usize) -> (usize, usize, f32) {
        let scale = (n_in as f32) / (n_out as f32);
        let pos = ((o as f32 + 0.5) * scale - 0.5).max(0.0);
        let i0 = (pos.floor() as usize).min(n_in - 1);
        let i1 = (i0 + 1).min(n_in - 1);
        (i0, i1, pos - i0 as f32)
    }

    let out_size = SAM2_IMG_SIZE;
    let plane = out_size * out_size;
    let mut nchw = vec![0f32; 3 * plane];
    let pixel = |y: usize, x: usize, c: usize| f32::from(rgb[(y * w_in + x) * 3 + c]);

    // Horizontal taps are identical for every output row; compute them once.
    let col_taps: Vec<(usize, usize, f32)> =
        (0..out_size).map(|x_o| taps(x_o, w_in, out_size)).collect();

    for y_o in 0..out_size {
        let (y0, y1, dy) = taps(y_o, h_in, out_size);
        for (x_o, &(x0, x1, dx)) in col_taps.iter().enumerate() {
            for c in 0..3 {
                let top = pixel(y0, x0, c) * (1.0 - dx) + pixel(y0, x1, c) * dx;
                let bot = pixel(y1, x0, c) * (1.0 - dx) + pixel(y1, x1, c) * dx;
                let v01 = (top * (1.0 - dy) + bot * dy) / 255.0;
                nchw[c * plane + y_o * out_size + x_o] =
                    (v01 - SAM2_PIXEL_MEAN[c]) / SAM2_PIXEL_STD[c];
            }
        }
    }
    nchw
}
/// Runs the patch-embedding convolution over a preprocessed image and adds the
/// precomputed positional embedding, producing channels-last patch tokens.
///
/// The convolution uses kernel `SAM2_PATCH_KERNEL`, stride `SAM2_PATCH_STRIDE`
/// and zero padding `SAM2_PATCH_PADDING`, with weights laid out
/// `[embed_dim, 3, k, k]`. `image_nchw` must be `[3, SAM2_IMG_SIZE, SAM2_IMG_SIZE]`.
/// The result is `[grid, grid, embed_dim]` flattened row-major.
///
/// # Errors
/// Returns an error if the image, bias, weight or positional-embedding buffers
/// do not have the lengths implied by `pre.embed_dim`, `pre.grid` and the
/// patch kernel. `pre` has public fields, so these are checked here rather
/// than trusted.
pub fn assemble_patch_tokens(pre: &Sam2PreprocessWeights, image_nchw: &[f32]) -> Result<Vec<f32>> {
    let e = pre.embed_dim;
    let grid = pre.grid;
    let k = SAM2_PATCH_KERNEL;
    let s = SAM2_PATCH_STRIDE;
    let pad = SAM2_PATCH_PADDING;
    let h = SAM2_IMG_SIZE;
    let w = SAM2_IMG_SIZE;
    // Weights per output channel: 3 input channels x k x k kernel positions.
    let taps = 3 * k * k;

    // Validate every buffer before doing any work.
    ensure!(
        image_nchw.len() == 3 * h * w,
        "image must be [3, {}, {}] NCHW, got len {}",
        SAM2_IMG_SIZE,
        SAM2_IMG_SIZE,
        image_nchw.len()
    );
    ensure!(
        pre.patch_proj_b.len() == e,
        "patch_proj_b expected len {e}, got {}",
        pre.patch_proj_b.len()
    );
    ensure!(
        pre.patch_proj_w.len() == e * taps,
        "patch_proj_w expected len {}, got {}",
        e * taps,
        pre.patch_proj_w.len()
    );
    ensure!(
        pre.pos_embed_full.len() == grid * grid * e,
        "pos_embed_full size mismatch: expected {}, got {}",
        grid * grid * e,
        pre.pos_embed_full.len()
    );

    let mut out = vec![0f32; grid * grid * e];
    for py in 0..grid {
        for px in 0..grid {
            let token = py * grid + px;
            let dst = &mut out[token * e..(token + 1) * e];
            dst.copy_from_slice(&pre.patch_proj_b);
            for ky in 0..k {
                // Taps that land in the zero padding contribute nothing.
                let iy = match (py * s + ky).checked_sub(pad) {
                    Some(iy) if iy < h => iy,
                    _ => continue,
                };
                for kx in 0..k {
                    let ix = match (px * s + kx).checked_sub(pad) {
                        Some(ix) if ix < w => ix,
                        _ => continue,
                    };
                    for c in 0..3 {
                        let v = image_nchw[c * h * w + iy * w + ix];
                        let tap = c * k * k + ky * k + kx;
                        // The `tap`-th weight of each output channel, one every `taps` floats.
                        let weights = pre.patch_proj_w.iter().skip(tap).step_by(taps);
                        for (acc, &wt) in dst.iter_mut().zip(weights) {
                            *acc += v * wt;
                        }
                    }
                }
            }
        }
    }

    for (o, &p) in out.iter_mut().zip(&pre.pos_embed_full) {
        *o += p;
    }
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A uniform image normalizes to the same value everywhere and yields a
    /// `[3, SAM2_IMG_SIZE, SAM2_IMG_SIZE]` buffer.
    #[test]
    fn preprocess_shape_and_range() {
        let (h, w) = (50, 30);
        let img = vec![128u8; h * w * 3];
        let nchw = preprocess_image(&img, h, w);
        let s = SAM2_IMG_SIZE;
        assert_eq!(nchw.len(), 3 * s * s);
        for c in 0..3 {
            let expected = (128.0 / 255.0 - SAM2_PIXEL_MEAN[c]) / SAM2_PIXEL_STD[c];
            let mid = nchw[c * s * s + (s / 2) * s + s / 2];
            assert!(
                (mid - expected).abs() < 1e-4,
                "channel {c}: {mid} vs {expected}"
            );
        }
    }

    /// Resizing to the same size must reproduce the input.
    #[test]
    fn bicubic_identity() {
        let src: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let mut dst = vec![0f32; 64];
        bicubic_resize_2d(&src, 8, 8, &mut dst, 8, 8);
        for (i, (a, b)) in src.iter().zip(&dst).enumerate() {
            assert!((a - b).abs() < 1e-4, "identity broken at {i}: {a} vs {b}");
        }
    }

    /// The cubic weights sum to one, so upsampling a constant plane keeps it constant.
    #[test]
    fn bicubic_preserves_constant() {
        let src = vec![2.5f32; 4 * 4];
        let mut dst = vec![0f32; 9 * 9];
        bicubic_resize_2d(&src, 4, 4, &mut dst, 9, 9);
        for (i, v) in dst.iter().enumerate() {
            assert!((v - 2.5).abs() < 1e-4, "pixel {i}: {v}");
        }
    }
}