1use super::config::Sam2DecoderConfig;
47use super::transformer::{
48 Sam2TwoWayTransformerWeights, add_inplace, extract_two_way_transformer_weights, linear,
49 two_way_transformer_forward,
50};
51use super::upscale_ir::Sam2MaskUpscaleCompiled;
52use anyhow::{Result, ensure};
53use rlx_core::weight_map::WeightMap;
54use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
55use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
56use rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled;
57
58pub struct Sam2MaskDecoderWeights {
59 pub iou_token: Vec<f32>, pub mask_tokens: Vec<f32>, pub obj_score_token: Option<Vec<f32>>,
63 pub transformer: Sam2TwoWayTransformerWeights,
64
65 pub upscale_conv1_w: Vec<f32>,
67 pub upscale_conv1_b: Vec<f32>,
68 pub upscale_ln_g: Vec<f32>,
69 pub upscale_ln_b: Vec<f32>,
70 pub upscale_conv2_w: Vec<f32>,
72 pub upscale_conv2_b: Vec<f32>,
73
74 pub conv_s0_w: Option<Vec<f32>>,
78 pub conv_s0_b: Option<Vec<f32>>,
79 pub conv_s1_w: Option<Vec<f32>>,
80 pub conv_s1_b: Option<Vec<f32>>,
81
82 pub hyper_mlps: Vec<Sam2HypernetMlp>,
85
86 pub iou_head: Sam2HypernetMlp,
89 pub iou_use_sigmoid: bool,
91
92 pub obj_score_head: Option<Sam2HypernetMlp>,
95
96 pub obj_ptr_proj: Option<Sam2HypernetMlp>,
100
101 pub transformer_dim: usize,
102 pub num_mask_tokens: usize,
103 pub use_high_res_features: bool,
104 pub pred_obj_scores: bool,
105 pub use_multimask_token_for_obj_ptr: bool,
106 pub dynamic_multimask_via_stability: bool,
107 pub dynamic_multimask_stability_delta: f32,
108 pub dynamic_multimask_stability_thresh: f32,
109}
110
/// A stack of fully-connected layers as evaluated by `mlp_forward`: ReLU after every layer
/// except the last, and an optional logistic sigmoid on the final output.
#[derive(Debug, Clone)]
pub struct Sam2HypernetMlp {
    /// Layers in evaluation order.
    pub layers: Vec<Sam2MlpLayer>,
    /// Apply a logistic sigmoid to the final layer's output.
    pub sigmoid_output: bool,
}

/// One fully-connected layer of a [`Sam2HypernetMlp`].
#[derive(Debug, Clone)]
pub struct Sam2MlpLayer {
    /// Weight matrix (`extract_mlp` validates its shape as `[out_d, in_d]`).
    pub w: Vec<f32>,
    /// Bias vector.
    pub b: Vec<f32>,
    /// Number of input features.
    pub in_d: usize,
    /// Number of output features.
    pub out_d: usize,
}
122
123pub fn extract_mask_decoder_weights(
124 weights: &mut WeightMap,
125 cfg: &Sam2DecoderConfig,
126) -> Result<Sam2MaskDecoderWeights> {
127 let transformer_dim = cfg.transformer_dim;
128 let num_mask_tokens = cfg.num_mask_tokens;
129
130 let (iou_token, sh) = weights.take("sam_mask_decoder.iou_token.weight")?;
131 ensure!(
132 sh == vec![1, transformer_dim],
133 "iou_token shape {sh:?} not [1, {transformer_dim}]"
134 );
135 let (mask_tokens, sh) = weights.take("sam_mask_decoder.mask_tokens.weight")?;
136 ensure!(
137 sh == vec![num_mask_tokens, transformer_dim],
138 "mask_tokens shape {sh:?} not [{num_mask_tokens}, {transformer_dim}]"
139 );
140
141 let obj_score_token = if cfg.pred_obj_scores {
142 let (data, sh) = weights.take("sam_mask_decoder.obj_score_token.weight")?;
143 ensure!(
144 sh == vec![1, transformer_dim],
145 "obj_score_token shape {sh:?} not [1, {transformer_dim}]"
146 );
147 Some(data)
148 } else {
149 None
150 };
151
152 let q4 = transformer_dim / 4;
154 let q8 = transformer_dim / 8;
155 let (upscale_conv1_w, sh) = weights.take("sam_mask_decoder.output_upscaling.0.weight")?;
156 ensure!(
157 sh == vec![transformer_dim, q4, 2, 2],
158 "output_upscaling.0.weight shape {sh:?} not [{transformer_dim}, {q4}, 2, 2]"
159 );
160 let (upscale_conv1_b, _) = weights.take("sam_mask_decoder.output_upscaling.0.bias")?;
161 let (upscale_ln_g, _) = weights.take("sam_mask_decoder.output_upscaling.1.weight")?;
162 let (upscale_ln_b, _) = weights.take("sam_mask_decoder.output_upscaling.1.bias")?;
163 let (upscale_conv2_w, sh) = weights.take("sam_mask_decoder.output_upscaling.3.weight")?;
164 ensure!(
165 sh == vec![q4, q8, 2, 2],
166 "output_upscaling.3.weight shape {sh:?} not [{q4}, {q8}, 2, 2]"
167 );
168 let (upscale_conv2_b, _) = weights.take("sam_mask_decoder.output_upscaling.3.bias")?;
169
170 let (conv_s0_w, conv_s0_b, conv_s1_w, conv_s1_b) = if cfg.use_high_res_features {
172 let (s0w, sh) = weights.take("sam_mask_decoder.conv_s0.weight")?;
173 ensure!(
174 sh == vec![q8, transformer_dim, 1, 1],
175 "conv_s0.weight shape {sh:?} not [{q8}, {transformer_dim}, 1, 1]"
176 );
177 let (s0b, _) = weights.take("sam_mask_decoder.conv_s0.bias")?;
178 let (s1w, sh) = weights.take("sam_mask_decoder.conv_s1.weight")?;
179 ensure!(
180 sh == vec![q4, transformer_dim, 1, 1],
181 "conv_s1.weight shape {sh:?} not [{q4}, {transformer_dim}, 1, 1]"
182 );
183 let (s1b, _) = weights.take("sam_mask_decoder.conv_s1.bias")?;
184 (Some(s0w), Some(s0b), Some(s1w), Some(s1b))
185 } else {
186 (None, None, None, None)
187 };
188
189 let mut hyper_mlps = Vec::with_capacity(num_mask_tokens);
191 for i in 0..num_mask_tokens {
192 let mlp = extract_mlp(
193 weights,
194 &format!("sam_mask_decoder.output_hypernetworks_mlps.{i}"),
195 transformer_dim,
196 transformer_dim,
197 q8,
198 3,
199 false,
200 )?;
201 hyper_mlps.push(mlp);
202 }
203
204 let iou_head = extract_mlp(
206 weights,
207 "sam_mask_decoder.iou_prediction_head",
208 transformer_dim,
209 cfg.iou_head_hidden_dim,
210 num_mask_tokens,
211 cfg.iou_head_depth,
212 cfg.iou_prediction_use_sigmoid,
213 )?;
214
215 let obj_score_head = if cfg.pred_obj_scores {
218 if cfg.pred_obj_scores_mlp {
219 Some(extract_mlp(
220 weights,
221 "sam_mask_decoder.pred_obj_score_head",
222 transformer_dim,
223 transformer_dim,
224 1,
225 3,
226 false,
227 )?)
228 } else {
229 let (w, sh) = weights.take("sam_mask_decoder.pred_obj_score_head.weight")?;
230 ensure!(
231 sh == vec![1, transformer_dim],
232 "pred_obj_score_head.weight shape {sh:?} not [1, {transformer_dim}]"
233 );
234 let (b, _) = weights.take("sam_mask_decoder.pred_obj_score_head.bias")?;
235 Some(Sam2HypernetMlp {
236 layers: vec![Sam2MlpLayer {
237 w,
238 b,
239 in_d: transformer_dim,
240 out_d: 1,
241 }],
242 sigmoid_output: false,
243 })
244 }
245 } else {
246 None
247 };
248
249 let obj_ptr_proj = if cfg.use_object_pointer {
254 if cfg.use_mlp_for_obj_ptr_proj {
255 Some(extract_mlp(
256 weights,
257 "obj_ptr_proj",
258 transformer_dim,
259 transformer_dim,
260 transformer_dim,
261 3,
262 false,
263 )?)
264 } else {
265 let (w, sh) = weights.take("obj_ptr_proj.weight")?;
266 ensure!(
267 sh == vec![transformer_dim, transformer_dim],
268 "obj_ptr_proj.weight shape {sh:?} not [{transformer_dim}, {transformer_dim}]"
269 );
270 let (b, _) = weights.take("obj_ptr_proj.bias")?;
271 Some(Sam2HypernetMlp {
272 layers: vec![Sam2MlpLayer {
273 w,
274 b,
275 in_d: transformer_dim,
276 out_d: transformer_dim,
277 }],
278 sigmoid_output: false,
279 })
280 }
281 } else {
282 None
283 };
284
285 let transformer = extract_two_way_transformer_weights(
286 weights,
287 transformer_dim,
288 cfg.transformer_depth,
289 cfg.transformer_num_heads,
290 cfg.transformer_mlp_dim,
291 )?;
292
293 Ok(Sam2MaskDecoderWeights {
294 iou_token,
295 mask_tokens,
296 obj_score_token,
297 transformer,
298 upscale_conv1_w,
299 upscale_conv1_b,
300 upscale_ln_g,
301 upscale_ln_b,
302 upscale_conv2_w,
303 upscale_conv2_b,
304 conv_s0_w,
305 conv_s0_b,
306 conv_s1_w,
307 conv_s1_b,
308 hyper_mlps,
309 iou_head,
310 iou_use_sigmoid: cfg.iou_prediction_use_sigmoid,
311 obj_score_head,
312 obj_ptr_proj,
313 transformer_dim,
314 num_mask_tokens,
315 use_high_res_features: cfg.use_high_res_features,
316 pred_obj_scores: cfg.pred_obj_scores,
317 use_multimask_token_for_obj_ptr: cfg.use_multimask_token_for_obj_ptr,
318 dynamic_multimask_via_stability: cfg.dynamic_multimask_via_stability,
319 dynamic_multimask_stability_delta: cfg.dynamic_multimask_stability_delta,
320 dynamic_multimask_stability_thresh: cfg.dynamic_multimask_stability_thresh,
321 })
322}
323
324fn extract_mlp(
325 weights: &mut WeightMap,
326 prefix: &str,
327 input_dim: usize,
328 hidden_dim: usize,
329 output_dim: usize,
330 num_layers: usize,
331 sigmoid_output: bool,
332) -> Result<Sam2HypernetMlp> {
333 let mut layers = Vec::with_capacity(num_layers);
334 for i in 0..num_layers {
335 let in_d = if i == 0 { input_dim } else { hidden_dim };
336 let out_d = if i + 1 == num_layers {
337 output_dim
338 } else {
339 hidden_dim
340 };
341 let (w, sh) = weights.take(&format!("{prefix}.layers.{i}.weight"))?;
342 ensure!(
343 sh == vec![out_d, in_d],
344 "{prefix}.layers.{i}.weight shape {sh:?} not [{out_d}, {in_d}]"
345 );
346 let (b, _) = weights.take(&format!("{prefix}.layers.{i}.bias"))?;
347 layers.push(Sam2MlpLayer { w, b, in_d, out_d });
348 }
349 Ok(Sam2HypernetMlp {
350 layers,
351 sigmoid_output,
352 })
353}
354
355pub fn mlp_forward(mlp: &Sam2HypernetMlp, x: &[f32], rows: usize) -> Vec<f32> {
358 let mut cur = x.to_vec();
359 let n = mlp.layers.len();
360 for (i, layer) in mlp.layers.iter().enumerate() {
361 cur = linear(&cur, &layer.w, &layer.b, rows, layer.in_d, layer.out_d);
362 if i + 1 < n {
363 for v in cur.iter_mut() {
364 if *v < 0.0 {
365 *v = 0.0;
366 }
367 }
368 }
369 }
370 if mlp.sigmoid_output {
371 for v in cur.iter_mut() {
372 *v = 1.0 / (1.0 + (-*v).exp());
373 }
374 }
375 cur
376}
377
/// Result of `mask_decoder_forward` for a single image.
#[derive(Debug, Clone)]
pub struct Sam2MaskDecoderOutput {
    /// Selected mask logits: `num_masks` planes of `h_out * w_out` values each.
    pub masks: Vec<f32>,
    /// Predicted IoU for each selected mask (`num_masks` values).
    pub iou_pred: Vec<f32>,
    /// Number of masks in `masks` / `iou_pred`.
    pub num_masks: usize,
    /// Mask height (4x the input grid).
    pub h_out: usize,
    /// Mask width (4x the input grid).
    pub w_out: usize,
    /// Decoder mask tokens chosen for the object pointer: `num_ptr_tokens` rows of
    /// `transformer_dim` values.
    pub sam_tokens_out: Vec<f32>,
    /// Number of rows in `sam_tokens_out`.
    pub num_ptr_tokens: usize,
    /// Object-score logits (a single `10.0` when the decoder has no object-score head).
    pub object_score_logits: Vec<f32>,
    /// `obj_ptr_proj` applied to `sam_tokens_out`, when a projection is available.
    pub object_pointer: Option<Vec<f32>>,
}
400
401#[allow(clippy::too_many_arguments)]
415pub fn mask_decoder_forward(
416 w: &Sam2MaskDecoderWeights,
417 upscale: &mut Sam2MaskUpscaleCompiled,
418 hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
419 hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
420 iou_head_ir: Option<&mut MlpReluCompiled>,
421 obj_score_head_ir: Option<&mut MlpReluCompiled>,
422 obj_ptr_proj_ir: Option<&mut MlpReluCompiled>,
423 tw_ir: Option<&mut TwoWayTransformerCompiled>,
424 image_embeddings: &[f32],
425 image_pe: &[f32],
426 sparse_prompt_embeddings: &[f32],
427 num_sparse_tokens: usize,
428 dense_prompt_embeddings: &[f32],
429 high_res_features: Option<(&[f32], &[f32])>,
430 multimask_output: bool,
431 grid: usize,
432) -> Result<Sam2MaskDecoderOutput> {
433 let e = w.transformer_dim;
434 let nm = w.num_mask_tokens;
435 let g = grid;
436 ensure!(
437 image_embeddings.len() == e * g * g,
438 "image_embeddings len {} ≠ E·g·g ({e}·{g}·{g})",
439 image_embeddings.len()
440 );
441 ensure!(
442 image_pe.len() == e * g * g,
443 "image_pe len {} ≠ E·g·g",
444 image_pe.len()
445 );
446 ensure!(
447 dense_prompt_embeddings.len() == e * g * g,
448 "dense_prompt_embeddings len {} ≠ E·g·g",
449 dense_prompt_embeddings.len()
450 );
451 ensure!(
452 sparse_prompt_embeddings.len() == num_sparse_tokens * e,
453 "sparse_prompt_embeddings len {} ≠ num_sparse·E ({num_sparse_tokens}·{e})",
454 sparse_prompt_embeddings.len()
455 );
456 if w.use_high_res_features {
457 let (s0, s1) = high_res_features.ok_or_else(|| {
458 anyhow::anyhow!("use_high_res_features=true requires (feat_s0, feat_s1)")
459 })?;
460 ensure!(
461 s0.len() == e * (4 * g) * (4 * g),
462 "feat_s0 len {} ≠ E·4g·4g ({e}·{}·{})",
463 s0.len(),
464 4 * g,
465 4 * g
466 );
467 ensure!(
468 s1.len() == e * (2 * g) * (2 * g),
469 "feat_s1 len {} ≠ E·2g·2g ({e}·{}·{})",
470 s1.len(),
471 2 * g,
472 2 * g
473 );
474 }
475
476 let s = if w.obj_score_token.is_some() { 1 } else { 0 };
478 let n_out_tokens = s + 1 + nm;
479 let q_n = n_out_tokens + num_sparse_tokens;
480 let mut tokens = Vec::with_capacity(q_n * e);
481 if let Some(obj) = &w.obj_score_token {
482 tokens.extend_from_slice(obj);
483 }
484 tokens.extend_from_slice(&w.iou_token);
485 tokens.extend_from_slice(&w.mask_tokens);
486 tokens.extend_from_slice(sparse_prompt_embeddings);
487
488 let mut src = image_embeddings.to_vec();
490 for i in 0..src.len() {
491 src[i] += dense_prompt_embeddings[i];
492 }
493 let pos_src = image_pe.to_vec();
494
495 let k_n = g * g;
497 let (hs, src_post) = if let Some(tw) = tw_ir {
498 if tw.masked && q_n <= tw.max_q_n && tw.k_n == k_n {
499 tw.run_nchw_masked(&tokens, q_n, &src, &pos_src, g)?
500 } else if !tw.masked && q_n == tw.max_q_n && tw.k_n == k_n {
501 tw.run_nchw(&tokens, &src, &pos_src, g)?
502 } else {
503 two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, g, g, q_n)
504 }
505 } else {
506 two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, g, g, q_n)
507 };
508
509 let obj_score_logits_pre = if let Some(ir) = obj_score_head_ir {
510 ir.run(&hs[..e], 1)?
511 } else if let Some(head) = &w.obj_score_head {
512 let token = &hs[..e];
513 mlp_forward(head, token, 1)
514 } else {
515 vec![10.0]
517 };
518
519 let iou_token_out: Vec<f32> = hs[s * e..(s + 1) * e].to_vec();
520 let mask_tokens_out = hs[(s + 1) * e..(s + 1 + nm) * e].to_vec();
521
522 let mut src_nchw = vec![0f32; e * g * g];
524 for ss in 0..g * g {
525 for c in 0..e {
526 src_nchw[c * g * g + ss] = src_post[ss * e + c];
527 }
528 }
529
530 let q8 = e / 8;
532 let h2 = g * 4;
533 let w2 = g * 4;
534 let (feat_s0, feat_s1) = high_res_features.unwrap_or((&[] as &[f32], &[] as &[f32]));
535 let up2 = upscale.run(&src_nchw, feat_s1, feat_s0, g)?;
536
537 let mut hyper_in = vec![0f32; nm * q8];
539 if let Some(mlps) = hyper_mlps_ir {
540 ensure!(
541 mlps.len() == nm,
542 "hyper_mlps_ir len {} ≠ num_mask_tokens {}",
543 mlps.len(),
544 nm
545 );
546 for i in 0..nm {
547 let token = &mask_tokens_out[i * e..(i + 1) * e];
548 let h = mlps[i].run(token, 1)?;
549 hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
550 }
551 } else {
552 for i in 0..nm {
553 let token = &mask_tokens_out[i * e..(i + 1) * e];
554 let h = mlp_forward(&w.hyper_mlps[i], token, 1);
555 hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
556 }
557 }
558 let spat = h2 * w2;
559 let mut masks_all = vec![0f32; nm * spat];
560 if let Some(hm) = hyper_matmul {
561 hm.run(&hyper_in, &up2, &mut masks_all)?;
562 } else {
563 rlx_cpu::blas::sgemm_auto(&hyper_in, &up2, &mut masks_all, nm, q8, spat);
564 }
565
566 let iou_pred_all = if let Some(head) = iou_head_ir {
568 head.run(&iou_token_out, 1)?
569 } else {
570 mlp_forward(&w.iou_head, &iou_token_out, 1)
571 };
572
573 let (masks, iou_pred, num_masks, ptr_indices): (Vec<f32>, Vec<f32>, usize, Vec<usize>) =
575 if multimask_output {
576 let masks = masks_all[spat..].to_vec();
578 let iou = iou_pred_all[1..].to_vec();
579 let ptr = if w.use_multimask_token_for_obj_ptr {
580 (1..nm).collect()
581 } else {
582 vec![0]
583 };
584 (masks, iou, nm - 1, ptr)
585 } else if w.dynamic_multimask_via_stability {
586 dynamic_multimask_via_stability(
587 &masks_all,
588 &iou_pred_all,
589 nm,
590 spat,
591 w.dynamic_multimask_stability_delta,
592 w.dynamic_multimask_stability_thresh,
593 )
594 } else {
595 let masks = masks_all[..spat].to_vec();
596 let iou = iou_pred_all[..1].to_vec();
597 (masks, iou, 1, vec![0])
598 };
599
600 let num_ptr_tokens = ptr_indices.len();
601 let mut sam_tokens_out = Vec::with_capacity(num_ptr_tokens * e);
602 for &pi in &ptr_indices {
603 sam_tokens_out.extend_from_slice(&mask_tokens_out[pi * e..(pi + 1) * e]);
604 }
605
606 let object_pointer = if let Some(ir) = obj_ptr_proj_ir {
607 if ir.compiled_rows() == num_ptr_tokens {
608 Some(ir.run(&sam_tokens_out, num_ptr_tokens)?)
609 } else {
610 w.obj_ptr_proj
611 .as_ref()
612 .map(|proj| mlp_forward(proj, &sam_tokens_out, num_ptr_tokens))
613 }
614 } else {
615 w.obj_ptr_proj
616 .as_ref()
617 .map(|proj| mlp_forward(proj, &sam_tokens_out, num_ptr_tokens))
618 };
619
620 Ok(Sam2MaskDecoderOutput {
621 masks,
622 iou_pred,
623 num_masks,
624 h_out: h2,
625 w_out: w2,
626 sam_tokens_out,
627 num_ptr_tokens,
628 object_score_logits: obj_score_logits_pre,
629 object_pointer,
630 })
631}
632
/// Single-mask selection with a stability fallback (SAM2 `_dynamic_multimask_via_stability`).
///
/// `masks_all` holds `spat` logits per mask token and `iou_pred_all` the matching IoU
/// predictions. Index 0 is the single-mask output; indices `1..` are the multimask outputs.
/// Mask 0 is kept when its [`mask_stability_score`] reaches `thresh`, or when there is no
/// multimask output to fall back to. Otherwise the multimask output with the highest
/// predicted IoU is returned (first one on ties).
///
/// Returns `(masks, iou_pred, num_masks, ptr_indices)`, always with exactly one mask.
/// `ptr_indices` is always `[0]`. As in the SAM2 reference decoder, the object pointer for a
/// single-mask output comes from mask token 0, even when the returned mask came from the
/// multimask fallback.
fn dynamic_multimask_via_stability(
    masks_all: &[f32],
    iou_pred_all: &[f32],
    _nm: usize,
    spat: usize,
    delta: f32,
    thresh: f32,
) -> (Vec<f32>, Vec<f32>, usize, Vec<usize>) {
    let single_mask = &masks_all[..spat];
    let mm_masks = &masks_all[spat..];
    let mm_iou = &iou_pred_all[1..];

    // Keep mask 0 when it is stable. Also keep it when there are no multimask outputs
    // (a single mask token), which previously panicked on an empty slice.
    if mm_iou.is_empty() || mask_stability_score(single_mask, delta) >= thresh {
        return (single_mask.to_vec(), iou_pred_all[..1].to_vec(), 1, vec![0]);
    }

    // Argmax over the multimask IoU predictions; strict `>` keeps the first of equal maxima.
    let best = mm_iou
        .iter()
        .enumerate()
        .fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
            if v > bv { (i, v) } else { (bi, bv) }
        })
        .0;

    let masks = mm_masks[best * spat..(best + 1) * spat].to_vec();
    (masks, vec![mm_iou[best]], 1, vec![0])
}

/// Stability score of a mask: the number of logits `> delta` divided by the number of
/// logits `> -delta`.
///
/// For `delta >= 0` the first set is contained in the second, so the ratio is the IoU of the
/// masks thresholded at `+delta` and `-delta`. Returns `1.0` when no logit exceeds `-delta`.
fn mask_stability_score(mask_logits: &[f32], delta: f32) -> f32 {
    let (inner, outer) = mask_logits
        .iter()
        .fold((0usize, 0usize), |(hi, lo), &v| {
            (hi + usize::from(v > delta), lo + usize::from(v > -delta))
        });
    if outer == 0 {
        1.0
    } else {
        inner as f32 / outer as f32
    }
}
686
687#[allow(dead_code)]
688fn _silence_add_inplace(x: &mut [f32], y: &[f32]) {
689 add_inplace(x, y);
690}