1use super::config::{
31 SAM_IMG_SIZE, SAM_PATCH_SIZE, SAM_PIXEL_MEAN, SAM_PIXEL_STD, SamEncoderConfig,
32};
33use anyhow::{Result, ensure};
34use rlx_core::weight_map::WeightMap;
35
/// Host-side weights used to turn a padded image into patch tokens.
///
/// Produced by `extract_preprocess_weights` and read by
/// `assemble_patch_tokens`.
#[derive(Debug, Clone)]
pub struct SamPreprocessWeights {
    /// Patch projection matrix, row-major `[3 * P * P, embed_dim]` where `P`
    /// is the patch size. This is the transpose of the checkpoint's
    /// `[embed_dim, 3, P, P]` kernel, so each input element owns one
    /// contiguous `embed_dim`-wide row.
    pub patch_proj_w: Vec<f32>,
    /// Projection bias; copied into every `embed_dim`-wide output token
    /// before the weighted sum is accumulated.
    pub patch_proj_b: Vec<f32>,
    /// Absolute position embedding, flattened from `[1, hw, hw, embed_dim]`.
    /// `None` when the encoder config does not use absolute positions.
    pub pos_embed: Option<Vec<f32>>,
    /// Width of one output token.
    pub embed_dim: usize,
    /// Number of patches along one side of the (square) patch grid.
    pub hw: usize,
}
51
52pub(super) fn extract_preprocess_weights(
53 weights: &mut WeightMap,
54 cfg: &SamEncoderConfig,
55) -> Result<SamPreprocessWeights> {
56 let e = cfg.embed_dim;
57 let hw = cfg.num_patches_per_side();
58 let pd = 3 * SAM_PATCH_SIZE * SAM_PATCH_SIZE;
59
60 let (proj_raw, proj_shape) = weights.take("image_encoder.patch_embed.proj.weight")?;
62 ensure!(
63 proj_shape == vec![e, 3, SAM_PATCH_SIZE, SAM_PATCH_SIZE],
64 "patch_embed.proj.weight expected [{e}, 3, {SAM_PATCH_SIZE}, {SAM_PATCH_SIZE}], got {proj_shape:?}"
65 );
66 let mut patch_proj_w = vec![0f32; e * pd];
69 for ei in 0..e {
70 for d in 0..pd {
71 patch_proj_w[d * e + ei] = proj_raw[ei * pd + d];
72 }
73 }
74 let (patch_proj_b, _) = weights.take("image_encoder.patch_embed.proj.bias")?;
75
76 let pos_embed = if cfg.use_abs_pos {
77 let (data, shape) = weights.take("image_encoder.pos_embed")?;
78 ensure!(
79 shape == vec![1, hw, hw, e],
80 "pos_embed expected [1, {hw}, {hw}, {e}], got {shape:?}"
81 );
82 Some(data)
83 } else {
84 None
85 };
86
87 Ok(SamPreprocessWeights {
88 patch_proj_w,
89 patch_proj_b,
90 pos_embed,
91 embed_dim: e,
92 hw,
93 })
94}
95
96pub fn preprocess_image(rgb: &[u8], h_in: usize, w_in: usize) -> (Vec<f32>, (usize, usize)) {
104 let scale = (SAM_IMG_SIZE as f32) / (h_in.max(w_in) as f32);
105 let new_h = ((h_in as f32) * scale).round() as usize;
106 let new_w = ((w_in as f32) * scale).round() as usize;
107 let mut resized = vec![0f32; 3 * new_h * new_w];
109 let sx = (w_in as f32 - 1.0) / (new_w.max(1) as f32 - 1.0).max(1.0);
110 let sy = (h_in as f32 - 1.0) / (new_h.max(1) as f32 - 1.0).max(1.0);
111 for y in 0..new_h {
112 let fy = y as f32 * sy;
113 let y0 = fy.floor() as usize;
114 let y1 = (y0 + 1).min(h_in - 1);
115 let dy = fy - y0 as f32;
116 for x in 0..new_w {
117 let fx = x as f32 * sx;
118 let x0 = fx.floor() as usize;
119 let x1 = (x0 + 1).min(w_in - 1);
120 let dx = fx - x0 as f32;
121 for c in 0..3 {
122 let p00 = rgb[(y0 * w_in + x0) * 3 + c] as f32;
123 let p01 = rgb[(y0 * w_in + x1) * 3 + c] as f32;
124 let p10 = rgb[(y1 * w_in + x0) * 3 + c] as f32;
125 let p11 = rgb[(y1 * w_in + x1) * 3 + c] as f32;
126 let top = p00 * (1.0 - dx) + p01 * dx;
127 let bot = p10 * (1.0 - dx) + p11 * dx;
128 let v = top * (1.0 - dy) + bot * dy;
129 resized[c * new_h * new_w + y * new_w + x] =
131 (v - SAM_PIXEL_MEAN[c]) / SAM_PIXEL_STD[c];
132 }
133 }
134 }
135 let mut padded = vec![0f32; 3 * SAM_IMG_SIZE * SAM_IMG_SIZE];
137 for c in 0..3 {
138 for y in 0..new_h {
139 let src_row = c * new_h * new_w + y * new_w;
140 let dst_row = c * SAM_IMG_SIZE * SAM_IMG_SIZE + y * SAM_IMG_SIZE;
141 padded[dst_row..dst_row + new_w].copy_from_slice(&resized[src_row..src_row + new_w]);
142 }
143 }
144 (padded, (new_h, new_w))
145}
146
147pub fn assemble_patch_tokens(pre: &SamPreprocessWeights, image_nchw: &[f32]) -> Result<Vec<f32>> {
152 let e = pre.embed_dim;
153 let hw = pre.hw;
154 let pd = 3 * SAM_PATCH_SIZE * SAM_PATCH_SIZE;
155 ensure!(
156 image_nchw.len() == 3 * SAM_IMG_SIZE * SAM_IMG_SIZE,
157 "image must be [3, {SAM_IMG_SIZE}, {SAM_IMG_SIZE}] NCHW, got len {}",
158 image_nchw.len()
159 );
160
161 let mut out = vec![0f32; hw * hw * e];
162 let mut patch_buf = vec![0f32; pd];
163 for py in 0..hw {
164 for px in 0..hw {
165 for c in 0..3 {
168 for ry in 0..SAM_PATCH_SIZE {
169 let src_y = py * SAM_PATCH_SIZE + ry;
170 for rx in 0..SAM_PATCH_SIZE {
171 let src_x = px * SAM_PATCH_SIZE + rx;
172 let src = c * SAM_IMG_SIZE * SAM_IMG_SIZE + src_y * SAM_IMG_SIZE + src_x;
173 let dst = c * SAM_PATCH_SIZE * SAM_PATCH_SIZE + ry * SAM_PATCH_SIZE + rx;
174 patch_buf[dst] = image_nchw[src];
175 }
176 }
177 }
178 let row = py * hw + px;
180 let dst = &mut out[row * e..(row + 1) * e];
181 dst.copy_from_slice(&pre.patch_proj_b);
182 for d in 0..pd {
183 let v = patch_buf[d];
184 if v == 0.0 {
185 continue;
186 }
187 let w_row = &pre.patch_proj_w[d * e..(d + 1) * e];
188 for k in 0..e {
189 dst[k] += v * w_row[k];
190 }
191 }
192 }
193 }
194
195 if let Some(pos) = &pre.pos_embed {
197 ensure!(pos.len() == hw * hw * e, "pos_embed size mismatch");
198 for i in 0..hw * hw * e {
199 out[i] += pos[i];
200 }
201 }
202 Ok(out)
203}