use super::config::{
SAM_IMG_SIZE, SAM_PATCH_SIZE, SAM_PIXEL_MEAN, SAM_PIXEL_STD, SamEncoderConfig,
};
use anyhow::{Result, ensure};
use rlx_core::weight_map::WeightMap;
/// Weights needed to turn a preprocessed image into patch tokens on the CPU.
///
/// Produced by `extract_preprocess_weights` and consumed by `assemble_patch_tokens`.
#[derive(Debug, Clone)]
pub struct SamPreprocessWeights {
    /// Patch projection weights, stored transposed as `[3 * P * P, embed_dim]` row-major
    /// (`P` = `SAM_PATCH_SIZE`): the weight for flattened patch element `d` and output
    /// channel `k` lives at `d * embed_dim + k`.
    pub patch_proj_w: Vec<f32>,
    /// Patch projection bias; every output token starts as a copy of this, so it must hold
    /// exactly `embed_dim` values.
    pub patch_proj_b: Vec<f32>,
    /// Optional absolute positional embedding, flattened from `[1, hw, hw, embed_dim]` and
    /// added element-wise to the patch tokens.
    pub pos_embed: Option<Vec<f32>>,
    /// Number of channels per patch token.
    pub embed_dim: usize,
    /// Number of patches along each side of the (square) patch grid.
    pub hw: usize,
}
/// Takes the patch-embedding tensors (and, when `cfg.use_abs_pos` is set, the absolute
/// positional embedding) out of `weights` and repacks them as [`SamPreprocessWeights`].
///
/// The projection weight is read as `[embed_dim, 3, P, P]` (`P` = `SAM_PATCH_SIZE`) and
/// stored transposed as `[3 * P * P, embed_dim]`, so that all output-channel weights for one
/// flattened patch element are contiguous.
///
/// # Errors
///
/// Returns an error if a required tensor cannot be taken from `weights`, or if the
/// projection weight, projection bias, or positional embedding does not have the size
/// implied by `cfg`.
pub(super) fn extract_preprocess_weights(
    weights: &mut WeightMap,
    cfg: &SamEncoderConfig,
) -> Result<SamPreprocessWeights> {
    let e = cfg.embed_dim;
    let hw = cfg.num_patches_per_side();
    // Number of values in one flattened RGB patch.
    let pd = 3 * SAM_PATCH_SIZE * SAM_PATCH_SIZE;

    let (proj_raw, proj_shape) = weights.take("image_encoder.patch_embed.proj.weight")?;
    ensure!(
        proj_shape == vec![e, 3, SAM_PATCH_SIZE, SAM_PATCH_SIZE],
        "patch_embed.proj.weight expected [{e}, 3, {SAM_PATCH_SIZE}, {SAM_PATCH_SIZE}], got {proj_shape:?}"
    );

    // Transpose [e, pd] -> [pd, e].
    let mut patch_proj_w = vec![0f32; e * pd];
    for ei in 0..e {
        for d in 0..pd {
            patch_proj_w[d * e + ei] = proj_raw[ei * pd + d];
        }
    }

    let (patch_proj_b, bias_shape) = weights.take("image_encoder.patch_embed.proj.bias")?;
    // The bias is copied verbatim into every token downstream, so a wrong length must be
    // rejected here rather than surfacing later as a slice-length panic.
    ensure!(
        patch_proj_b.len() == e,
        "patch_embed.proj.bias expected [{e}], got {bias_shape:?}"
    );

    let pos_embed = if cfg.use_abs_pos {
        let (data, shape) = weights.take("image_encoder.pos_embed")?;
        ensure!(
            shape == vec![1, hw, hw, e],
            "pos_embed expected [1, {hw}, {hw}, {e}], got {shape:?}"
        );
        Some(data)
    } else {
        None
    };

    Ok(SamPreprocessWeights {
        patch_proj_w,
        patch_proj_b,
        pos_embed,
        embed_dim: e,
        hw,
    })
}
/// Resizes, normalizes and zero-pads an RGB image into a square planar tensor.
///
/// `rgb` is read as interleaved 8-bit pixels (`(y * w_in + x) * 3 + channel`). The image is
/// scaled so that its longer side becomes `SAM_IMG_SIZE`, using bilinear interpolation in
/// which the first and last output samples map onto the first and last input samples. Each
/// channel is then normalized with `SAM_PIXEL_MEAN` / `SAM_PIXEL_STD` and placed in the
/// top-left corner of a `[3, SAM_IMG_SIZE, SAM_IMG_SIZE]` buffer whose remainder stays `0.0`.
///
/// Returns the padded buffer together with the resized `(height, width)` before padding.
///
/// # Panics
///
/// Panics if `rgb` is too short for the pixels that get sampled.
pub fn preprocess_image(rgb: &[u8], h_in: usize, w_in: usize) -> (Vec<f32>, (usize, usize)) {
    let longest_side = h_in.max(w_in) as f32;
    let scale = (SAM_IMG_SIZE as f32) / longest_side;
    let new_h = ((h_in as f32) * scale).round() as usize;
    let new_w = ((w_in as f32) * scale).round() as usize;

    // Source-pixel step per output pixel; the denominators are clamped to at least 1.
    let step_x = (w_in as f32 - 1.0) / (new_w.max(1) as f32 - 1.0).max(1.0);
    let step_y = (h_in as f32 - 1.0) / (new_h.max(1) as f32 - 1.0).max(1.0);

    let plane = SAM_IMG_SIZE * SAM_IMG_SIZE;
    let mut padded = vec![0f32; 3 * plane];

    // Fetches one channel of one source pixel as f32.
    let sample = |row: usize, col: usize, ch: usize| rgb[(row * w_in + col) * 3 + ch] as f32;

    for out_y in 0..new_h {
        let src_y = out_y as f32 * step_y;
        let row_top = src_y.floor() as usize;
        let row_bottom = (row_top + 1).min(h_in - 1);
        let weight_y = src_y - row_top as f32;

        for out_x in 0..new_w {
            let src_x = out_x as f32 * step_x;
            let col_left = src_x.floor() as usize;
            let col_right = (col_left + 1).min(w_in - 1);
            let weight_x = src_x - col_left as f32;

            for ch in 0..3 {
                let top_left = sample(row_top, col_left, ch);
                let top_right = sample(row_top, col_right, ch);
                let bottom_left = sample(row_bottom, col_left, ch);
                let bottom_right = sample(row_bottom, col_right, ch);

                // Blend horizontally along both rows, then vertically between them.
                let upper = top_left * (1.0 - weight_x) + top_right * weight_x;
                let lower = bottom_left * (1.0 - weight_x) + bottom_right * weight_x;
                let blended = upper * (1.0 - weight_y) + lower * weight_y;

                // Write straight into the padded canvas; untouched cells remain zero.
                padded[ch * plane + out_y * SAM_IMG_SIZE + out_x] =
                    (blended - SAM_PIXEL_MEAN[ch]) / SAM_PIXEL_STD[ch];
            }
        }
    }

    (padded, (new_h, new_w))
}
/// Applies the patch-embedding projection to a preprocessed image on the CPU.
///
/// `image_nchw` must be a planar `[3, SAM_IMG_SIZE, SAM_IMG_SIZE]` buffer. The image is cut
/// into a `pre.hw x pre.hw` grid of `SAM_PATCH_SIZE`-pixel square patches; each patch is
/// flattened in `(channel, row, column)` order, multiplied by `pre.patch_proj_w`, and offset
/// by `pre.patch_proj_b`. When `pre.pos_embed` is present it is added element-wise.
///
/// Returns the tokens as a `[hw * hw, embed_dim]` row-major buffer, with the token for patch
/// `(py, px)` at row `py * hw + px`.
///
/// # Errors
///
/// Returns an error if the image buffer, the patch grid, or any of the weight buffers in
/// `pre` has a size inconsistent with `pre.embed_dim` / `pre.hw`. All checks run before any
/// projection work, so malformed weights are reported as errors instead of panicking midway.
pub fn assemble_patch_tokens(pre: &SamPreprocessWeights, image_nchw: &[f32]) -> Result<Vec<f32>> {
    let e = pre.embed_dim;
    let hw = pre.hw;
    let pd = 3 * SAM_PATCH_SIZE * SAM_PATCH_SIZE;
    let plane = SAM_IMG_SIZE * SAM_IMG_SIZE;

    ensure!(
        image_nchw.len() == 3 * plane,
        "image must be [3, {SAM_IMG_SIZE}, {SAM_IMG_SIZE}] NCHW, got len {}",
        image_nchw.len()
    );
    // A grid larger than the image would read past a channel plane (or past the buffer).
    ensure!(
        hw * SAM_PATCH_SIZE <= SAM_IMG_SIZE,
        "patch grid {hw}x{hw} with patch size {SAM_PATCH_SIZE} exceeds image size {SAM_IMG_SIZE}"
    );
    ensure!(
        pre.patch_proj_w.len() == pd * e,
        "patch_proj_w size mismatch: expected {}, got {}",
        pd * e,
        pre.patch_proj_w.len()
    );
    ensure!(
        pre.patch_proj_b.len() == e,
        "patch_proj_b size mismatch: expected {e}, got {}",
        pre.patch_proj_b.len()
    );
    if let Some(pos) = &pre.pos_embed {
        ensure!(pos.len() == hw * hw * e, "pos_embed size mismatch");
    }

    let mut out = vec![0f32; hw * hw * e];
    let mut patch_buf = vec![0f32; pd];

    for py in 0..hw {
        for px in 0..hw {
            // Gather the patch into `patch_buf` in (channel, row, column) order, one
            // contiguous pixel row at a time.
            for c in 0..3 {
                for ry in 0..SAM_PATCH_SIZE {
                    let src_y = py * SAM_PATCH_SIZE + ry;
                    let src = c * plane + src_y * SAM_IMG_SIZE + px * SAM_PATCH_SIZE;
                    let dst = (c * SAM_PATCH_SIZE + ry) * SAM_PATCH_SIZE;
                    patch_buf[dst..dst + SAM_PATCH_SIZE]
                        .copy_from_slice(&image_nchw[src..src + SAM_PATCH_SIZE]);
                }
            }

            let row = py * hw + px;
            let token = &mut out[row * e..(row + 1) * e];
            token.copy_from_slice(&pre.patch_proj_b);

            for (d, &v) in patch_buf.iter().enumerate() {
                // Exact zeros contribute nothing to the sum, so skip their weight rows.
                if v == 0.0 {
                    continue;
                }
                let w_row = &pre.patch_proj_w[d * e..(d + 1) * e];
                for (acc, &w) in token.iter_mut().zip(w_row) {
                    *acc += v * w;
                }
            }
        }
    }

    if let Some(pos) = &pre.pos_embed {
        for (o, &p) in out.iter_mut().zip(pos) {
            *o += p;
        }
    }

    Ok(out)
}