1use super::config::SAM3_DET_DIM;
31use super::neck_branch_ir::Sam3NeckBranchCompiled;
32use super::vision_encoder::Sam3VisionOutput;
33use anyhow::{Result, ensure};
34use rlx_core::weight_map::WeightMap;
35use rlx_flow::CompileProfile;
36use rlx_runtime::Device;
37
/// One pyramid level produced by the SAM3 neck.
///
/// `pos` is laid out channel-major as `[channels, h, w]`; on the host path `features`
/// uses the same layout.
#[derive(Clone, Debug)]
pub struct Sam3FeatureLevel {
    /// Feature-map values for this level.
    pub features: Vec<f32>,
    /// Sine positional encoding for this level.
    pub pos: Vec<f32>,
    /// Height of the level in grid cells.
    pub h: usize,
    /// Width of the level in grid cells.
    pub w: usize,
    /// Number of feature channels.
    pub channels: usize,
}
46
47#[derive(Default)]
48pub struct Sam3NeckWeights {
49 pub loaded: bool,
50 pub branches: Vec<Sam3NeckBranch>,
51 pub ir: Vec<Sam3NeckBranchCompiled>,
53}
54
/// Host-side weights for a single branch of the SAM3 neck.
#[derive(Default, Clone)]
pub struct Sam3NeckBranch {
    /// Resampling factor relative to the backbone grid (4.0, 2.0, 1.0 or 0.5).
    pub scale: f32,
    /// First 2x2 transposed-conv weight; present for the 4.0 and 2.0 branches.
    pub dconv0_w: Option<Vec<f32>>,
    /// Bias matching `dconv0_w`.
    pub dconv0_b: Option<Vec<f32>>,
    /// Second 2x2 transposed-conv weight; present only for the 4.0 branch.
    pub dconv1_w: Option<Vec<f32>>,
    /// Bias matching `dconv1_w`.
    pub dconv1_b: Option<Vec<f32>>,
    /// 1x1 projection weight (projects to `SAM3_DET_DIM` channels on the host path).
    pub c1x1_w: Vec<f32>,
    /// Bias of the 1x1 projection.
    pub c1x1_b: Vec<f32>,
    /// Number of input channels the 1x1 projection expects.
    pub c1x1_in: usize,
    /// 3x3 conv weight, shaped `[SAM3_DET_DIM, SAM3_DET_DIM, 3, 3]` when extracted.
    pub c3x3_w: Vec<f32>,
    /// Bias of the 3x3 conv.
    pub c3x3_b: Vec<f32>,
}
72
73pub fn extract_neck_weights(weights: &mut WeightMap) -> Result<Sam3NeckWeights> {
74 let prefixes = [
75 "detector.backbone.vision_backbone",
76 "backbone.vision_backbone",
77 "vision_backbone",
78 ];
79 let scales = [4.0f32, 2.0, 1.0, 0.5];
80 let mut branches = Vec::with_capacity(scales.len());
81 for (i, scale) in scales.iter().enumerate() {
82 let mut found = None;
83 for pref in prefixes {
84 let base = format!("{pref}.convs.{i}");
85 if weights.has(&format!("{base}.conv_1x1.weight")) {
86 found = Some(base);
87 break;
88 }
89 }
90 let base = found.ok_or_else(|| {
91 anyhow::anyhow!("SAM3 neck branch {i} (scale={scale}) not found in checkpoint")
92 })?;
93
94 let (dconv0_w, dconv0_b) = if (*scale - 4.0).abs() < 1e-6 {
95 let (w, ws) = weights.take(&format!("{base}.dconv_2x2_0.weight"))?;
96 ensure!(
97 ws == vec![1024, 512, 2, 2],
98 "dconv_2x2_0.weight shape {ws:?}"
99 );
100 let (b, _) = weights.take(&format!("{base}.dconv_2x2_0.bias"))?;
101 (Some(w), Some(b))
102 } else if (*scale - 2.0).abs() < 1e-6 {
103 let (w, ws) = weights.take(&format!("{base}.dconv_2x2.weight"))?;
104 ensure!(ws == vec![1024, 512, 2, 2], "dconv_2x2.weight shape {ws:?}");
105 let (b, _) = weights.take(&format!("{base}.dconv_2x2.bias"))?;
106 (Some(w), Some(b))
107 } else {
108 (None, None)
109 };
110 let (dconv1_w, dconv1_b) = if (*scale - 4.0).abs() < 1e-6 {
111 let (w, ws) = weights.take(&format!("{base}.dconv_2x2_1.weight"))?;
112 ensure!(
113 ws == vec![512, 256, 2, 2],
114 "dconv_2x2_1.weight shape {ws:?}"
115 );
116 let (b, _) = weights.take(&format!("{base}.dconv_2x2_1.bias"))?;
117 (Some(w), Some(b))
118 } else {
119 (None, None)
120 };
121
122 let (c1x1_w, c1_shape) = weights.take(&format!("{base}.conv_1x1.weight"))?;
123 ensure!(c1_shape.len() == 4 && c1_shape[2] == 1 && c1_shape[3] == 1);
124 let c1x1_in = c1_shape[1];
125 let (c1x1_b, _) = weights.take(&format!("{base}.conv_1x1.bias"))?;
126 let (c3x3_w, c3_shape) = weights.take(&format!("{base}.conv_3x3.weight"))?;
127 ensure!(
128 c3_shape == vec![SAM3_DET_DIM, SAM3_DET_DIM, 3, 3],
129 "conv_3x3.weight shape {c3_shape:?}"
130 );
131 let (c3x3_b, _) = weights.take(&format!("{base}.conv_3x3.bias"))?;
132
133 branches.push(Sam3NeckBranch {
134 scale: *scale,
135 dconv0_w,
136 dconv0_b,
137 dconv1_w,
138 dconv1_b,
139 c1x1_w,
140 c1x1_b,
141 c1x1_in,
142 c3x3_w,
143 c3x3_b,
144 });
145 }
146
147 for pref in prefixes {
149 let base = format!("{pref}.sam2_convs");
150 let keys: Vec<String> = weights
151 .keys()
152 .filter(|k| k.starts_with(&base))
153 .map(|s| s.to_string())
154 .collect();
155 for k in keys {
156 let _ = weights.take(&k);
157 }
158 }
159
160 Ok(Sam3NeckWeights {
161 loaded: true,
162 branches,
163 ir: Vec::new(),
164 })
165}
166
167pub fn compile_neck_branches(
169 neck: &mut Sam3NeckWeights,
170 in_c: usize,
171 grid: usize,
172 device: Device,
173 profile: &CompileProfile,
174) -> Result<()> {
175 neck.ir = neck
176 .branches
177 .iter()
178 .map(|b| Sam3NeckBranchCompiled::compile_with_profile(b, in_c, grid, grid, device, profile))
179 .collect::<Result<_>>()?;
180 Ok(())
181}
182
183pub fn apply_neck_native(
184 weights: &mut Sam3NeckWeights,
185 vision: &Sam3VisionOutput,
186) -> Result<Vec<Sam3FeatureLevel>> {
187 ensure!(
188 weights.loaded,
189 "SAM3 neck weights not loaded; call extract_neck_weights()"
190 );
191 let grid = vision.grid;
192 let dim = vision.dim;
193
194 let mut x_nchw = vec![0f32; dim * grid * grid];
196 for y in 0..grid {
197 for xc in 0..grid {
198 for c in 0..dim {
199 x_nchw[c * grid * grid + y * grid + xc] = vision.tokens[(y * grid + xc) * dim + c];
200 }
201 }
202 }
203
204 let mut levels = Vec::with_capacity(weights.branches.len());
205 if weights.ir.len() == weights.branches.len() {
206 for compiled in &mut weights.ir {
207 let features = compiled.run(&x_nchw, dim, grid, grid)?;
208 let pos = position_encoding_sine_sam3(SAM3_DET_DIM, compiled.out_h, compiled.out_w);
209 levels.push(Sam3FeatureLevel {
210 features,
211 pos,
212 h: compiled.out_h,
213 w: compiled.out_w,
214 channels: SAM3_DET_DIM,
215 });
216 }
217 } else {
218 for branch in &weights.branches {
219 let level = apply_branch_host(branch, &x_nchw, dim, grid, grid)?;
220 levels.push(level);
221 }
222 }
223 Ok(levels)
224}
225
226fn apply_branch_host(
227 branch: &Sam3NeckBranch,
228 x: &[f32],
229 in_c: usize,
230 h: usize,
231 w: usize,
232) -> Result<Sam3FeatureLevel> {
233 let mut cur = x.to_vec();
234 let mut cur_c = in_c;
235 let mut cur_h = h;
236 let mut cur_w = w;
237
238 if (branch.scale - 4.0).abs() < 1e-6 {
239 let dw0 = branch.dconv0_w.as_ref().unwrap();
240 let db0 = branch.dconv0_b.as_ref().unwrap();
241 cur = conv_transpose2d_stride2_k2(&cur, cur_c, 512, cur_h, cur_w, dw0, db0);
242 cur_c = 512;
243 cur_h *= 2;
244 cur_w *= 2;
245 gelu_inplace(&mut cur);
246 let dw1 = branch.dconv1_w.as_ref().unwrap();
247 let db1 = branch.dconv1_b.as_ref().unwrap();
248 cur = conv_transpose2d_stride2_k2(&cur, cur_c, 256, cur_h, cur_w, dw1, db1);
249 cur_c = 256;
250 cur_h *= 2;
251 cur_w *= 2;
252 } else if (branch.scale - 2.0).abs() < 1e-6 {
253 let dw = branch.dconv0_w.as_ref().unwrap();
254 let db = branch.dconv0_b.as_ref().unwrap();
255 cur = conv_transpose2d_stride2_k2(&cur, cur_c, 512, cur_h, cur_w, dw, db);
256 cur_c = 512;
257 cur_h *= 2;
258 cur_w *= 2;
259 } else if (branch.scale - 0.5).abs() < 1e-6 {
260 cur = maxpool2x2_stride2(&cur, cur_c, cur_h, cur_w);
261 cur_h /= 2;
262 cur_w /= 2;
263 }
265 ensure!(cur_c == branch.c1x1_in, "branch input channels mismatch");
266
267 cur = conv2d_1x1(
269 &cur,
270 cur_c,
271 SAM3_DET_DIM,
272 cur_h,
273 cur_w,
274 &branch.c1x1_w,
275 &branch.c1x1_b,
276 );
277 cur_c = SAM3_DET_DIM;
278
279 cur = conv2d_3x3_pad1(&cur, cur_c, cur_h, cur_w, &branch.c3x3_w, &branch.c3x3_b);
281
282 let pos = position_encoding_sine_sam3(SAM3_DET_DIM, cur_h, cur_w);
283 Ok(Sam3FeatureLevel {
284 features: cur,
285 pos,
286 h: cur_h,
287 w: cur_w,
288 channels: cur_c,
289 })
290}
291
292fn gelu_inplace(x: &mut [f32]) {
293 let inv_sqrt2 = 1.0f32 / std::f32::consts::SQRT_2;
297 for v in x.iter_mut() {
298 *v = 0.5 * *v * (1.0 + erf_approx(*v * inv_sqrt2));
299 }
300}
301
/// Rational approximation of the error function (Abramowitz & Stegun 7.1.26).
///
/// Evaluated on `|x|`; the sign is restored at the end because `erf` is odd.
fn erf_approx(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const A1: f32 = 0.2548296;
    const A2: f32 = -0.2844967;
    const A3: f32 = 1.4214138;
    const A4: f32 = -1.4531521;
    const A5: f32 = 1.0614054;

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    // Horner evaluation of a1 + a2*t + ... + a5*t^4.
    let poly = (((A5 * t + A4) * t + A3) * t + A2) * t + A1;
    let magnitude = 1.0 - poly * t * (-ax * ax).exp();
    if x < 0.0 {
        -magnitude
    } else {
        magnitude
    }
}
316
/// 2x2 max-pool with stride 2 over a `[c, h, w]` tensor; returns `[c, h / 2, w / 2]`.
///
/// An odd trailing row or column is dropped (floor division of the spatial size).
fn maxpool2x2_stride2(input: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let (pooled_h, pooled_w) = (h / 2, w / 2);
    let plane = h * w;
    let mut pooled = Vec::with_capacity(c * pooled_h * pooled_w);
    for ch in 0..c {
        let src = &input[ch * plane..(ch + 1) * plane];
        for py in 0..pooled_h {
            // The two input rows covered by this output row.
            let top = &src[2 * py * w..];
            let bottom = &src[(2 * py + 1) * w..];
            pooled.extend((0..pooled_w).map(|px| {
                let x0 = 2 * px;
                top[x0].max(top[x0 + 1]).max(bottom[x0]).max(bottom[x0 + 1])
            }));
        }
    }
    pooled
}
338
339fn conv2d_1x1(
340 input: &[f32],
341 in_c: usize,
342 out_c: usize,
343 h: usize,
344 w: usize,
345 weight: &[f32], bias: &[f32],
347) -> Vec<f32> {
348 let n = h * w;
352 let mut out = vec![0f32; out_c * n];
353 rlx_cpu::blas::sgemm(weight, input, &mut out, out_c, in_c, n);
354 for oc in 0..out_c {
356 let b = bias[oc];
357 let row = &mut out[oc * n..(oc + 1) * n];
358 for v in row {
359 *v += b;
360 }
361 }
362 out
363}
364
/// 3x3 convolution with stride 1 and zero padding 1, mapping `c` channels to `c` channels.
///
/// `input` is `[c, h, w]`, `weight` is `[c_out, c_in, 3, 3]` and `bias` has `c` entries.
/// The output has the same `[c, h, w]` shape as the input.
fn conv2d_3x3_pad1(
    input: &[f32],
    c: usize,
    h: usize,
    w: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let plane = h * w;
    // Every output plane starts out holding its channel's bias.
    let mut out: Vec<f32> = bias[..c]
        .iter()
        .flat_map(|&b| std::iter::repeat(b).take(plane))
        .collect();
    for oc in 0..c {
        for ic in 0..c {
            let base = (oc * c + ic) * 9;
            let kernel = &weight[base..base + 9];
            let src = &input[ic * plane..(ic + 1) * plane];
            let dst = &mut out[oc * plane..(oc + 1) * plane];
            for oy in 0..h {
                for ox in 0..w {
                    let mut acc = 0.0f32;
                    for ky in 0..3 {
                        // Taps outside the map read the zero padding and are skipped.
                        let Some(iy) = (oy + ky).checked_sub(1).filter(|&r| r < h) else {
                            continue;
                        };
                        for kx in 0..3 {
                            let Some(ix) = (ox + kx).checked_sub(1).filter(|&col| col < w)
                            else {
                                continue;
                            };
                            acc += src[iy * w + ix] * kernel[ky * 3 + kx];
                        }
                    }
                    dst[oy * w + ox] += acc;
                }
            }
        }
    }
    out
}
409
/// 2x2 transposed convolution with stride 2: each input pixel scatters into a 2x2 patch.
///
/// `input` is `[in_c, h, w]`, `weight` is `[in_c, out_c, 2, 2]` and `bias` has `out_c`
/// entries. Returns `[out_c, 2h, 2w]`.
fn conv_transpose2d_stride2_k2(
    input: &[f32],
    in_c: usize,
    out_c: usize,
    h: usize,
    w: usize,
    weight: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    let out_w = 2 * w;
    let out_plane = 2 * h * out_w;
    let in_plane = h * w;
    let mut out = vec![0f32; out_c * out_plane];
    for (oc, &b) in bias[..out_c].iter().enumerate() {
        out[oc * out_plane..(oc + 1) * out_plane].fill(b);
    }
    for ic in 0..in_c {
        let src = &input[ic * in_plane..(ic + 1) * in_plane];
        for (pix, &v) in src.iter().enumerate() {
            // Zero inputs contribute nothing, so their weights are never read.
            if v == 0.0 {
                continue;
            }
            let (iy, ix) = (pix / w, pix % w);
            // Taps in row-major order: (0,0), (0,1), (1,0), (1,1).
            for tap in 0..4 {
                let (ky, kx) = (tap / 2, tap % 2);
                let dst = (2 * iy + ky) * out_w + 2 * ix + kx;
                for oc in 0..out_c {
                    out[oc * out_plane + dst] += v * weight[(ic * out_c + oc) * 4 + tap];
                }
            }
        }
    }
    out
}
451
/// Build the 2-D sine positional encoding used by SAM3, laid out as `[d_model, h, w]`.
///
/// The first `d_model / 2` channels encode the row coordinate and the rest the column
/// coordinate. Within each half, channel `i` holds `sin` (even `i`) or `cos` (odd `i`) of
/// `((k + 1) / (extent + 1e-6) * 2π) / 10000^(2 * floor(i / 2) / (d_model / 2))`.
///
/// # Panics
/// Panics if `d_model` is odd.
pub fn position_encoding_sine_sam3(d_model: usize, h: usize, w: usize) -> Vec<f32> {
    assert!(d_model.is_multiple_of(2), "d_model must be even");
    let half = d_model / 2;
    let two_pi = 2.0 * std::f32::consts::PI;
    let eps = 1e-6f32;
    let temperature = 10000.0f32;

    let dim_t: Vec<f32> = (0..half)
        .map(|i| temperature.powf(2.0 * ((i / 2) as f32) / half as f32))
        .collect();

    // Row and column encodings are separable, so each coordinate is encoded only once.
    let encode = |pos: usize, extent: usize| -> Vec<f32> {
        let norm = ((pos + 1) as f32) / (extent as f32 + eps) * two_pi;
        dim_t
            .iter()
            .enumerate()
            .map(|(i, &period)| {
                let phase = norm / period;
                if i % 2 == 0 {
                    phase.sin()
                } else {
                    phase.cos()
                }
            })
            .collect()
    };
    let rows: Vec<Vec<f32>> = (0..h).map(|y| encode(y, h)).collect();
    let cols: Vec<Vec<f32>> = (0..w).map(|x| encode(x, w)).collect();

    let plane = h * w;
    let mut out = vec![0.0f32; d_model * plane];
    for i in 0..half {
        for (y, row) in rows.iter().enumerate() {
            for (x, col) in cols.iter().enumerate() {
                out[i * plane + y * w + x] = row[i];
                out[(half + i) * plane + y * w + x] = col[i];
            }
        }
    }
    out
}