1use super::config::{
24 SAM3_IMG_SIZE, SAM3_PATCH_GRID, SAM3_PIXEL_MEAN, SAM3_PIXEL_STD, Sam3VitConfig,
25};
26use anyhow::{Result, ensure};
27use rlx_core::weight_map::WeightMap;
28
/// Host-side tensors and dimensions needed to turn a preprocessed image into
/// patch tokens (see `assemble_patch_tokens`).
///
/// All buffers are flat, row-major `f32` vectors.
#[derive(Clone, Debug)]
pub struct Sam3PreprocessWeights {
    /// Patch projection matrix laid out as `[3 * patch_size * patch_size, embed_dim]`,
    /// i.e. the checkpoint's `[embed_dim, 3, patch_size, patch_size]` kernel flattened
    /// and transposed so that each input element owns a contiguous row of outputs.
    pub patch_proj_w: Vec<f32>,
    /// Patch projection bias, one value per embedding channel (all zeros when the
    /// model is configured without a patch-embedding bias).
    pub patch_proj_b: Vec<f32>,
    /// Optional absolute position embedding, `grid * grid * embed_dim` values that
    /// are added element-wise to the patch tokens.
    pub pos_embed: Option<Vec<f32>>,
    /// Number of channels in each patch token.
    pub embed_dim: usize,
    /// Side length, in pixels, of one square patch.
    pub patch_size: usize,
    /// Number of patches along each side of the image.
    pub grid: usize,
}
39
40pub(crate) fn extract_preprocess_weights(
41 weights: &mut WeightMap,
42 cfg: &Sam3VitConfig,
43) -> Result<Sam3PreprocessWeights> {
44 let e = cfg.embed_dim;
45 let ps = cfg.patch_size;
46 let grid = cfg.patch_grid();
47 let pd = 3 * ps * ps;
48
49 let (proj_raw, proj_shape) = take_first(
50 weights,
51 &[
52 "detector.backbone.vision_backbone.trunk.patch_embed.proj.weight",
53 "detector.backbone.visual.trunk.patch_embed.proj.weight",
54 "backbone.vision_backbone.trunk.patch_embed.proj.weight",
55 "backbone.visual.trunk.patch_embed.proj.weight",
56 "visual.trunk.patch_embed.proj.weight",
57 "trunk.patch_embed.proj.weight",
58 ],
59 )?;
60 ensure!(
61 proj_shape == vec![e, 3, ps, ps],
62 "SAM3 patch_embed.proj.weight expected [{e}, 3, {ps}, {ps}], got {proj_shape:?}"
63 );
64
65 let mut patch_proj_w = vec![0f32; e * pd];
66 for ei in 0..e {
67 for d in 0..pd {
68 patch_proj_w[d * e + ei] = proj_raw[ei * pd + d];
69 }
70 }
71
72 let patch_proj_b = if cfg.bias_patch_embed {
73 let (data, shape) = take_first(
74 weights,
75 &[
76 "detector.backbone.vision_backbone.trunk.patch_embed.proj.bias",
77 "detector.backbone.visual.trunk.patch_embed.proj.bias",
78 "backbone.vision_backbone.trunk.patch_embed.proj.bias",
79 "backbone.visual.trunk.patch_embed.proj.bias",
80 "visual.trunk.patch_embed.proj.bias",
81 "trunk.patch_embed.proj.bias",
82 ],
83 )?;
84 ensure!(
85 shape == vec![e],
86 "SAM3 patch bias expected [{e}], got {shape:?}"
87 );
88 data
89 } else {
90 vec![0.0; e]
91 };
92
93 let pos_embed = if cfg.use_abs_pos {
94 take_optional_first(
95 weights,
96 &[
97 "detector.backbone.vision_backbone.trunk.pos_embed",
98 "detector.backbone.visual.trunk.pos_embed",
99 "backbone.vision_backbone.trunk.pos_embed",
100 "backbone.visual.trunk.pos_embed",
101 "visual.trunk.pos_embed",
102 "trunk.pos_embed",
103 ],
104 )?
105 .map(|(data, shape)| materialize_pos_embed(&data, &shape, cfg, grid, e))
106 .transpose()?
107 } else {
108 None
109 };
110
111 Ok(Sam3PreprocessWeights {
112 patch_proj_w,
113 patch_proj_b,
114 pos_embed,
115 embed_dim: e,
116 patch_size: ps,
117 grid,
118 })
119}
120
121pub fn preprocess_image(rgb: &[u8], h_in: usize, w_in: usize) -> (Vec<f32>, (usize, usize)) {
123 let scale = (SAM3_IMG_SIZE as f32) / (h_in.max(w_in) as f32);
124 let new_h = ((h_in as f32) * scale).round() as usize;
125 let new_w = ((w_in as f32) * scale).round() as usize;
126
127 let mut resized = vec![0f32; 3 * new_h * new_w];
128 let sx = (w_in as f32 - 1.0) / (new_w.max(1) as f32 - 1.0).max(1.0);
129 let sy = (h_in as f32 - 1.0) / (new_h.max(1) as f32 - 1.0).max(1.0);
130 for y in 0..new_h {
131 let fy = y as f32 * sy;
132 let y0 = fy.floor() as usize;
133 let y1 = (y0 + 1).min(h_in - 1);
134 let dy = fy - y0 as f32;
135 for x in 0..new_w {
136 let fx = x as f32 * sx;
137 let x0 = fx.floor() as usize;
138 let x1 = (x0 + 1).min(w_in - 1);
139 let dx = fx - x0 as f32;
140 for c in 0..3 {
141 let p00 = rgb[(y0 * w_in + x0) * 3 + c] as f32 / 255.0;
142 let p01 = rgb[(y0 * w_in + x1) * 3 + c] as f32 / 255.0;
143 let p10 = rgb[(y1 * w_in + x0) * 3 + c] as f32 / 255.0;
144 let p11 = rgb[(y1 * w_in + x1) * 3 + c] as f32 / 255.0;
145 let top = p00 * (1.0 - dx) + p01 * dx;
146 let bot = p10 * (1.0 - dx) + p11 * dx;
147 let v = top * (1.0 - dy) + bot * dy;
148 resized[c * new_h * new_w + y * new_w + x] =
149 (v - SAM3_PIXEL_MEAN[c]) / SAM3_PIXEL_STD[c];
150 }
151 }
152 }
153
154 let mut padded = vec![0f32; 3 * SAM3_IMG_SIZE * SAM3_IMG_SIZE];
155 for c in 0..3 {
156 for y in 0..new_h {
157 let src_row = c * new_h * new_w + y * new_w;
158 let dst_row = c * SAM3_IMG_SIZE * SAM3_IMG_SIZE + y * SAM3_IMG_SIZE;
159 padded[dst_row..dst_row + new_w].copy_from_slice(&resized[src_row..src_row + new_w]);
160 }
161 }
162 (padded, (new_h, new_w))
163}
164
165pub fn assemble_patch_tokens(pre: &Sam3PreprocessWeights, image_nchw: &[f32]) -> Result<Vec<f32>> {
166 let e = pre.embed_dim;
167 let ps = pre.patch_size;
168 let grid = pre.grid;
169 let pd = 3 * ps * ps;
170 ensure!(
171 image_nchw.len() == 3 * SAM3_IMG_SIZE * SAM3_IMG_SIZE,
172 "SAM3 image must be [3, {SAM3_IMG_SIZE}, {SAM3_IMG_SIZE}] NCHW, got len {}",
173 image_nchw.len()
174 );
175 ensure!(
176 grid == SAM3_PATCH_GRID,
177 "SAM3 base grid must be {SAM3_PATCH_GRID}"
178 );
179
180 let mut out = vec![0f32; grid * grid * e];
181 let mut patch_buf = vec![0f32; pd];
182 for py in 0..grid {
183 for px in 0..grid {
184 for c in 0..3 {
185 for ry in 0..ps {
186 let src_y = py * ps + ry;
187 for rx in 0..ps {
188 let src_x = px * ps + rx;
189 let src = c * SAM3_IMG_SIZE * SAM3_IMG_SIZE + src_y * SAM3_IMG_SIZE + src_x;
190 let dst = c * ps * ps + ry * ps + rx;
191 patch_buf[dst] = image_nchw[src];
192 }
193 }
194 }
195 let row = py * grid + px;
196 let dst = &mut out[row * e..(row + 1) * e];
197 dst.copy_from_slice(&pre.patch_proj_b);
198 for d in 0..pd {
199 let v = patch_buf[d];
200 if v == 0.0 {
201 continue;
202 }
203 let w_row = &pre.patch_proj_w[d * e..(d + 1) * e];
204 for k in 0..e {
205 dst[k] += v * w_row[k];
206 }
207 }
208 }
209 }
210
211 if let Some(pos) = &pre.pos_embed {
212 ensure!(pos.len() == out.len(), "SAM3 pos_embed size mismatch");
213 for i in 0..out.len() {
214 out[i] += pos[i];
215 }
216 }
217
218 Ok(out)
219}
220
221fn materialize_pos_embed(
228 data: &[f32],
229 shape: &[usize],
230 cfg: &Sam3VitConfig,
231 grid: usize,
232 e: usize,
233) -> Result<Vec<f32>> {
234 if shape == [1, grid, grid, e] || shape == [grid, grid, e] {
235 return Ok(data.to_vec());
236 }
237 ensure!(
238 shape.len() == 3 && shape[0] == 1 && shape[2] == e,
239 "SAM3 pos_embed expected [1, *, {e}], got {shape:?}"
240 );
241 let num_positions = shape[1];
242 let has_cls = num_positions % 2 == 1;
243 let spatial = if has_cls {
244 num_positions - 1
245 } else {
246 num_positions
247 };
248 let pretrain_grid = (spatial as f64).sqrt().round() as usize;
249 ensure!(
250 pretrain_grid * pretrain_grid == spatial,
251 "SAM3 pos_embed spatial portion not square: {spatial} positions"
252 );
253
254 let src = if has_cls { &data[e..] } else { data };
255 let mut out = vec![0f32; grid * grid * e];
256
257 if cfg.tile_abs_pos {
258 for y in 0..grid {
259 for x in 0..grid {
260 let sy = y % pretrain_grid;
261 let sx = x % pretrain_grid;
262 let src_row = (sy * pretrain_grid + sx) * e;
263 let dst_row = (y * grid + x) * e;
264 out[dst_row..dst_row + e].copy_from_slice(&src[src_row..src_row + e]);
265 }
266 }
267 } else {
268 bicubic_interp_nhwc(src, pretrain_grid, pretrain_grid, &mut out, grid, grid, e);
271 }
272
273 Ok(out)
274}
275
276fn bicubic_interp_nhwc(
277 src: &[f32],
278 src_h: usize,
279 src_w: usize,
280 dst: &mut [f32],
281 dst_h: usize,
282 dst_w: usize,
283 c: usize,
284) {
285 let mut src_chw = vec![0f32; c * src_h * src_w];
287 for y in 0..src_h {
288 for x in 0..src_w {
289 for ch in 0..c {
290 src_chw[ch * src_h * src_w + y * src_w + x] = src[(y * src_w + x) * c + ch];
291 }
292 }
293 }
294 let scale_y = src_h as f32 / dst_h as f32;
295 let scale_x = src_w as f32 / dst_w as f32;
296 for y in 0..dst_h {
297 let fy = (y as f32 + 0.5) * scale_y - 0.5;
298 let y_floor = fy.floor() as i32;
299 let dy = fy - y_floor as f32;
300 let wy = cubic_weights(dy);
301 for x in 0..dst_w {
302 let fx = (x as f32 + 0.5) * scale_x - 0.5;
303 let x_floor = fx.floor() as i32;
304 let dx = fx - x_floor as f32;
305 let wx = cubic_weights(dx);
306 for ch in 0..c {
307 let plane = &src_chw[ch * src_h * src_w..(ch + 1) * src_h * src_w];
308 let mut v = 0.0f32;
309 for j in -1..=2 {
310 let sy = (y_floor + j).clamp(0, src_h as i32 - 1) as usize;
311 let mut row_acc = 0.0f32;
312 for i in -1..=2 {
313 let sx = (x_floor + i).clamp(0, src_w as i32 - 1) as usize;
314 row_acc += plane[sy * src_w + sx] * wx[(i + 1) as usize];
315 }
316 v += row_acc * wy[(j + 1) as usize];
317 }
318 dst[(y * dst_w + x) * c + ch] = v;
319 }
320 }
321 }
322}
323
324fn cubic_weights(t: f32) -> [f32; 4] {
325 let a = -0.75f32;
327 let t1 = 1.0 + t; let t2 = t; let t3 = 1.0 - t; let t4 = 2.0 - t; [
332 cubic_kernel(t1, a),
333 cubic_kernel(t2, a),
334 cubic_kernel(t3, a),
335 cubic_kernel(t4, a),
336 ]
337}
338
/// Evaluates the piecewise cubic convolution kernel with parameter `a` at `x`.
///
/// The kernel is symmetric in `x`, uses one polynomial for `|x| < 1`, another
/// for `1 <= |x| < 2`, and is zero everywhere else (including for NaN input,
/// which fails both comparisons).
fn cubic_kernel(x: f32, a: f32) -> f32 {
    match x.abs() {
        d if d < 1.0 => (a + 2.0) * d * d * d - (a + 3.0) * d * d + 1.0,
        d if d < 2.0 => a * d * d * d - 5.0 * a * d * d + 8.0 * a * d - 4.0 * a,
        _ => 0.0,
    }
}
349
350fn take_first(weights: &mut WeightMap, keys: &[&str]) -> Result<(Vec<f32>, Vec<usize>)> {
351 for key in keys {
352 if weights.has(key) {
353 return weights.take(key);
354 }
355 }
356 anyhow::bail!("none of the SAM3 weight keys were found: {keys:?}")
357}
358
359fn take_optional_first(
360 weights: &mut WeightMap,
361 keys: &[&str],
362) -> Result<Option<(Vec<f32>, Vec<usize>)>> {
363 for key in keys {
364 if weights.has(key) {
365 return weights.take(key).map(Some);
366 }
367 }
368 Ok(None)
369}