Skip to main content

rlx_sam2/
mask_decoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! SAM 2 mask decoder — host-side.
17//!
18//! Mirrors `sam2/modeling/sam/mask_decoder.py::MaskDecoder` exactly.
19//! Five differences vs the SAM v1 mask decoder:
20//!
21//!   1. **Object-score token + head.** When `pred_obj_scores=True`
22//!      (the default for SAM 2), an extra `obj_score_token` is
23//!      prepended to the iou+mask tokens; the decoder's first output
24//!      slot becomes `obj_score_logits = pred_obj_score_head(hs[:, 0])`.
25//!   2. **High-res features.** The upscaler eats the FpnNeck's
26//!      stride-4 and stride-8 levels via two 1×1 lateral convs
27//!      (`conv_s0`, `conv_s1`), additively fused into the upscaling
28//!      stack between the two `ConvTranspose2d` layers.
29//!   3. **`use_multimask_token_for_obj_ptr=True`.** When multimask
30//!      output is selected, the object-pointer projection consumes
31//!      `mask_tokens_out[:, 1:]` (the three multimask tokens) instead
//!      of `mask_tokens_out[:, 0:1]` (the single-mask token).
//!   4. **Dynamic multimask via stability**
//!      (`dynamic_multimask_via_stability=True` in some configs) — if
//!      `multimask_output` is false and the single-mask output's
//!      stability score is below a threshold, fall back to the best
//!      multimask output (ranked by predicted IoU). Implemented per the
//!      reference.
38//!   5. **Object-pointer projection.** A small MLP that turns the
39//!      selected mask token(s) into the pointer fed to the memory
40//!      attention layer (Phase 3 path). Weights live here on the
41//!      decoder side.
42//!
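//! Weight key prefix is `sam_mask_decoder.*` (SAM 2 nests the mask
//! decoder under `sam_mask_decoder` in the published checkpoints). The
//! one exception is the object-pointer projection, whose keys live at
//! the top level as `obj_ptr_proj.*`.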
45
46use super::config::Sam2DecoderConfig;
47use super::transformer::{
48    Sam2TwoWayTransformerWeights, add_inplace, extract_two_way_transformer_weights, linear,
49    two_way_transformer_forward,
50};
51use super::upscale_ir::Sam2MaskUpscaleCompiled;
52use anyhow::{Result, ensure};
53use rlx_core::weight_map::WeightMap;
54use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
55use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
56use rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled;
57
58pub struct Sam2MaskDecoderWeights {
59    pub iou_token: Vec<f32>,   // [1, transformer_dim]
60    pub mask_tokens: Vec<f32>, // [num_mask_tokens, transformer_dim]
61    /// Optional object-score token, populated when `pred_obj_scores=true`.
62    pub obj_score_token: Option<Vec<f32>>,
63    pub transformer: Sam2TwoWayTransformerWeights,
64
65    /// ConvTranspose2d in=transformer_dim, out=transformer_dim/4, k=2, s=2.
66    pub upscale_conv1_w: Vec<f32>,
67    pub upscale_conv1_b: Vec<f32>,
68    pub upscale_ln_g: Vec<f32>,
69    pub upscale_ln_b: Vec<f32>,
70    /// ConvTranspose2d in=transformer_dim/4, out=transformer_dim/8.
71    pub upscale_conv2_w: Vec<f32>,
72    pub upscale_conv2_b: Vec<f32>,
73
74    /// Optional high-res fusion 1×1 convs.
75    ///   - `conv_s0`: stride-4 features → transformer_dim/8 channels
76    ///   - `conv_s1`: stride-8 features → transformer_dim/4 channels
77    pub conv_s0_w: Option<Vec<f32>>,
78    pub conv_s0_b: Option<Vec<f32>>,
79    pub conv_s1_w: Option<Vec<f32>>,
80    pub conv_s1_b: Option<Vec<f32>>,
81
82    /// `num_mask_tokens` × 3-layer ReLU MLPs producing
83    /// `transformer_dim/8` per mask token.
84    pub hyper_mlps: Vec<Sam2HypernetMlp>,
85
86    /// IoU prediction head: 3-layer ReLU MLP `transformer_dim →
87    /// iou_head_hidden_dim → iou_head_hidden_dim → num_mask_tokens`.
88    pub iou_head: Sam2HypernetMlp,
89    /// `iou_prediction_use_sigmoid` flag.
90    pub iou_use_sigmoid: bool,
91
92    /// Optional object-score prediction head (3-layer MLP when the
93    /// `pred_obj_scores_mlp` flag is set; otherwise a plain Linear).
94    pub obj_score_head: Option<Sam2HypernetMlp>,
95
96    /// Optional object-pointer projection MLP. Reference shape:
97    /// `MLP(transformer_dim, transformer_dim, transformer_dim, 3)` if
98    /// `use_mlp_for_obj_ptr_proj=True`, otherwise `Linear(...)`.
99    pub obj_ptr_proj: Option<Sam2HypernetMlp>,
100
101    pub transformer_dim: usize,
102    pub num_mask_tokens: usize,
103    pub use_high_res_features: bool,
104    pub pred_obj_scores: bool,
105    pub use_multimask_token_for_obj_ptr: bool,
106    pub dynamic_multimask_via_stability: bool,
107    pub dynamic_multimask_stability_delta: f32,
108    pub dynamic_multimask_stability_thresh: f32,
109}
110
111pub struct Sam2HypernetMlp {
112    pub layers: Vec<Sam2MlpLayer>,
113    pub sigmoid_output: bool,
114}
115
/// One linear layer of a [`Sam2HypernetMlp`].
pub struct Sam2MlpLayer {
    /// Weight matrix, flattened `[out_d, in_d]`.
    pub w: Vec<f32>,
    /// Bias vector (one entry per output feature).
    pub b: Vec<f32>,
    /// Input feature width.
    pub in_d: usize,
    /// Output feature width.
    pub out_d: usize,
}
122
123pub fn extract_mask_decoder_weights(
124    weights: &mut WeightMap,
125    cfg: &Sam2DecoderConfig,
126) -> Result<Sam2MaskDecoderWeights> {
127    let transformer_dim = cfg.transformer_dim;
128    let num_mask_tokens = cfg.num_mask_tokens;
129
130    let (iou_token, sh) = weights.take("sam_mask_decoder.iou_token.weight")?;
131    ensure!(
132        sh == vec![1, transformer_dim],
133        "iou_token shape {sh:?} not [1, {transformer_dim}]"
134    );
135    let (mask_tokens, sh) = weights.take("sam_mask_decoder.mask_tokens.weight")?;
136    ensure!(
137        sh == vec![num_mask_tokens, transformer_dim],
138        "mask_tokens shape {sh:?} not [{num_mask_tokens}, {transformer_dim}]"
139    );
140
141    let obj_score_token = if cfg.pred_obj_scores {
142        let (data, sh) = weights.take("sam_mask_decoder.obj_score_token.weight")?;
143        ensure!(
144            sh == vec![1, transformer_dim],
145            "obj_score_token shape {sh:?} not [1, {transformer_dim}]"
146        );
147        Some(data)
148    } else {
149        None
150    };
151
152    // ConvTranspose2d weight convention in PyTorch: [in, out, kH, kW].
153    let q4 = transformer_dim / 4;
154    let q8 = transformer_dim / 8;
155    let (upscale_conv1_w, sh) = weights.take("sam_mask_decoder.output_upscaling.0.weight")?;
156    ensure!(
157        sh == vec![transformer_dim, q4, 2, 2],
158        "output_upscaling.0.weight shape {sh:?} not [{transformer_dim}, {q4}, 2, 2]"
159    );
160    let (upscale_conv1_b, _) = weights.take("sam_mask_decoder.output_upscaling.0.bias")?;
161    let (upscale_ln_g, _) = weights.take("sam_mask_decoder.output_upscaling.1.weight")?;
162    let (upscale_ln_b, _) = weights.take("sam_mask_decoder.output_upscaling.1.bias")?;
163    let (upscale_conv2_w, sh) = weights.take("sam_mask_decoder.output_upscaling.3.weight")?;
164    ensure!(
165        sh == vec![q4, q8, 2, 2],
166        "output_upscaling.3.weight shape {sh:?} not [{q4}, {q8}, 2, 2]"
167    );
168    let (upscale_conv2_b, _) = weights.take("sam_mask_decoder.output_upscaling.3.bias")?;
169
170    // High-res fusion convs (gated on `use_high_res_features`).
171    let (conv_s0_w, conv_s0_b, conv_s1_w, conv_s1_b) = if cfg.use_high_res_features {
172        let (s0w, sh) = weights.take("sam_mask_decoder.conv_s0.weight")?;
173        ensure!(
174            sh == vec![q8, transformer_dim, 1, 1],
175            "conv_s0.weight shape {sh:?} not [{q8}, {transformer_dim}, 1, 1]"
176        );
177        let (s0b, _) = weights.take("sam_mask_decoder.conv_s0.bias")?;
178        let (s1w, sh) = weights.take("sam_mask_decoder.conv_s1.weight")?;
179        ensure!(
180            sh == vec![q4, transformer_dim, 1, 1],
181            "conv_s1.weight shape {sh:?} not [{q4}, {transformer_dim}, 1, 1]"
182        );
183        let (s1b, _) = weights.take("sam_mask_decoder.conv_s1.bias")?;
184        (Some(s0w), Some(s0b), Some(s1w), Some(s1b))
185    } else {
186        (None, None, None, None)
187    };
188
189    // Hypernetwork MLPs.
190    let mut hyper_mlps = Vec::with_capacity(num_mask_tokens);
191    for i in 0..num_mask_tokens {
192        let mlp = extract_mlp(
193            weights,
194            &format!("sam_mask_decoder.output_hypernetworks_mlps.{i}"),
195            transformer_dim,
196            transformer_dim,
197            q8,
198            3,
199            false,
200        )?;
201        hyper_mlps.push(mlp);
202    }
203
204    // IoU prediction head.
205    let iou_head = extract_mlp(
206        weights,
207        "sam_mask_decoder.iou_prediction_head",
208        transformer_dim,
209        cfg.iou_head_hidden_dim,
210        num_mask_tokens,
211        cfg.iou_head_depth,
212        cfg.iou_prediction_use_sigmoid,
213    )?;
214
215    // Object-score head: 3-layer MLP when pred_obj_scores_mlp,
216    // else a plain Linear(d, 1).
217    let obj_score_head = if cfg.pred_obj_scores {
218        if cfg.pred_obj_scores_mlp {
219            Some(extract_mlp(
220                weights,
221                "sam_mask_decoder.pred_obj_score_head",
222                transformer_dim,
223                transformer_dim,
224                1,
225                3,
226                false,
227            )?)
228        } else {
229            let (w, sh) = weights.take("sam_mask_decoder.pred_obj_score_head.weight")?;
230            ensure!(
231                sh == vec![1, transformer_dim],
232                "pred_obj_score_head.weight shape {sh:?} not [1, {transformer_dim}]"
233            );
234            let (b, _) = weights.take("sam_mask_decoder.pred_obj_score_head.bias")?;
235            Some(Sam2HypernetMlp {
236                layers: vec![Sam2MlpLayer {
237                    w,
238                    b,
239                    in_d: transformer_dim,
240                    out_d: 1,
241                }],
242                sigmoid_output: false,
243            })
244        }
245    } else {
246        None
247    };
248
249    // Object-pointer projection. NB: lives at the *top level* of the
250    // SAM2Base module (sibling of `sam_mask_decoder`, not nested), so
251    // the checkpoint keys are `obj_ptr_proj.layers.{i}.weight` (not
252    // `sam_mask_decoder.obj_ptr_proj.…`). Same for `no_obj_ptr`.
253    let obj_ptr_proj = if cfg.use_object_pointer {
254        if cfg.use_mlp_for_obj_ptr_proj {
255            Some(extract_mlp(
256                weights,
257                "obj_ptr_proj",
258                transformer_dim,
259                transformer_dim,
260                transformer_dim,
261                3,
262                false,
263            )?)
264        } else {
265            let (w, sh) = weights.take("obj_ptr_proj.weight")?;
266            ensure!(
267                sh == vec![transformer_dim, transformer_dim],
268                "obj_ptr_proj.weight shape {sh:?} not [{transformer_dim}, {transformer_dim}]"
269            );
270            let (b, _) = weights.take("obj_ptr_proj.bias")?;
271            Some(Sam2HypernetMlp {
272                layers: vec![Sam2MlpLayer {
273                    w,
274                    b,
275                    in_d: transformer_dim,
276                    out_d: transformer_dim,
277                }],
278                sigmoid_output: false,
279            })
280        }
281    } else {
282        None
283    };
284
285    let transformer = extract_two_way_transformer_weights(
286        weights,
287        transformer_dim,
288        cfg.transformer_depth,
289        cfg.transformer_num_heads,
290        cfg.transformer_mlp_dim,
291    )?;
292
293    Ok(Sam2MaskDecoderWeights {
294        iou_token,
295        mask_tokens,
296        obj_score_token,
297        transformer,
298        upscale_conv1_w,
299        upscale_conv1_b,
300        upscale_ln_g,
301        upscale_ln_b,
302        upscale_conv2_w,
303        upscale_conv2_b,
304        conv_s0_w,
305        conv_s0_b,
306        conv_s1_w,
307        conv_s1_b,
308        hyper_mlps,
309        iou_head,
310        iou_use_sigmoid: cfg.iou_prediction_use_sigmoid,
311        obj_score_head,
312        obj_ptr_proj,
313        transformer_dim,
314        num_mask_tokens,
315        use_high_res_features: cfg.use_high_res_features,
316        pred_obj_scores: cfg.pred_obj_scores,
317        use_multimask_token_for_obj_ptr: cfg.use_multimask_token_for_obj_ptr,
318        dynamic_multimask_via_stability: cfg.dynamic_multimask_via_stability,
319        dynamic_multimask_stability_delta: cfg.dynamic_multimask_stability_delta,
320        dynamic_multimask_stability_thresh: cfg.dynamic_multimask_stability_thresh,
321    })
322}
323
324fn extract_mlp(
325    weights: &mut WeightMap,
326    prefix: &str,
327    input_dim: usize,
328    hidden_dim: usize,
329    output_dim: usize,
330    num_layers: usize,
331    sigmoid_output: bool,
332) -> Result<Sam2HypernetMlp> {
333    let mut layers = Vec::with_capacity(num_layers);
334    for i in 0..num_layers {
335        let in_d = if i == 0 { input_dim } else { hidden_dim };
336        let out_d = if i + 1 == num_layers {
337            output_dim
338        } else {
339            hidden_dim
340        };
341        let (w, sh) = weights.take(&format!("{prefix}.layers.{i}.weight"))?;
342        ensure!(
343            sh == vec![out_d, in_d],
344            "{prefix}.layers.{i}.weight shape {sh:?} not [{out_d}, {in_d}]"
345        );
346        let (b, _) = weights.take(&format!("{prefix}.layers.{i}.bias"))?;
347        layers.push(Sam2MlpLayer { w, b, in_d, out_d });
348    }
349    Ok(Sam2HypernetMlp {
350        layers,
351        sigmoid_output,
352    })
353}
354
355/// Forward through a ReLU MLP. Final layer is NOT followed by ReLU;
356/// optional sigmoid is applied to the output.
357pub fn mlp_forward(mlp: &Sam2HypernetMlp, x: &[f32], rows: usize) -> Vec<f32> {
358    let mut cur = x.to_vec();
359    let n = mlp.layers.len();
360    for (i, layer) in mlp.layers.iter().enumerate() {
361        cur = linear(&cur, &layer.w, &layer.b, rows, layer.in_d, layer.out_d);
362        if i + 1 < n {
363            for v in cur.iter_mut() {
364                if *v < 0.0 {
365                    *v = 0.0;
366                }
367            }
368        }
369    }
370    if mlp.sigmoid_output {
371        for v in cur.iter_mut() {
372            *v = 1.0 / (1.0 + (-*v).exp());
373        }
374    }
375    cur
376}
377
/// Output of [`mask_decoder_forward`] for a single image.
pub struct Sam2MaskDecoderOutput {
    /// `[num_masks, h_out, w_out]` mask logits. `num_masks` is
    /// `num_mask_tokens - 1` when `multimask_output` is requested, otherwise
    /// 1 (this includes the case where the dynamic-stability fallback picks
    /// a multimask output).
    pub masks: Vec<f32>,
    /// Predicted IoU for each returned mask, `[num_masks]`.
    pub iou_pred: Vec<f32>,
    /// Number of masks in `masks` (and scores in `iou_pred`).
    pub num_masks: usize,
    /// Mask height: 4 × the image-embedding grid.
    pub h_out: usize,
    /// Mask width: 4 × the image-embedding grid.
    pub w_out: usize,
    /// Mask-token embedding(s) selected for the object-pointer path,
    /// `[num_ptr_tokens, transformer_dim]`. Always populated (never empty).
    /// It holds one token, or `num_mask_tokens - 1` tokens when both
    /// `multimask_output` and `use_multimask_token_for_obj_ptr` are set.
    pub sam_tokens_out: Vec<f32>,
    /// Number of rows in `sam_tokens_out` (and in `object_pointer`).
    pub num_ptr_tokens: usize,
    /// Object-score logits, `[1]`. When `pred_obj_scores=false` this is a
    /// constant `+10` (matching the reference), so a downstream
    /// `obj_score_prob` evaluates to ~1.
    pub object_score_logits: Vec<f32>,
    /// Object-pointer projection of `sam_tokens_out`,
    /// `[num_ptr_tokens, transformer_dim]`. `None` when neither a compiled
    /// nor a host projection is available (e.g. `use_object_pointer=false`).
    pub object_pointer: Option<Vec<f32>>,
}
400
/// Run the SAM 2 mask decoder for a single image (batch size 1).
///
/// Each stage uses a caller-supplied compiled IR graph when one is given,
/// and otherwise falls back to the host implementation in this module. The
/// two-way transformer and the object-pointer projection also fall back
/// when the compiled graph's shape does not match the current call.
///
/// # Arguments
///
/// - `w`: decoder weights from [`extract_mask_decoder_weights`].
/// - `upscale`: compiled upscaling graph; always used (there is no host
///   fallback for this stage).
/// - `hyper_matmul`, `hyper_mlps_ir`, `iou_head_ir`, `obj_score_head_ir`,
///   `obj_ptr_proj_ir`, `tw_ir`: optional compiled graphs; `None` selects
///   the host path for that stage.
/// - `image_embeddings`: NCHW `[1, C=transformer_dim, grid, grid]`.
/// - `image_pe`: NCHW `[1, C=transformer_dim, grid, grid]`.
/// - `sparse_prompt_embeddings`: `[num_sparse, transformer_dim]`.
/// - `num_sparse_tokens`: number of rows in `sparse_prompt_embeddings`.
/// - `dense_prompt_embeddings`: `[transformer_dim, grid, grid]`.
/// - `high_res_features`: optional `(feat_s0, feat_s1)` where:
///
///   - `feat_s0`: stride-4 features `[transformer_dim, 4·grid, 4·grid]`
///   - `feat_s1`: stride-8 features `[transformer_dim, 2·grid, 2·grid]`
///
///   Required when `w.use_high_res_features` is set. Reference passes
///   these from the FpnNeck.
/// - `multimask_output`: return the multimask outputs (tokens `1..nm`)
///   instead of the single-mask output (token 0).
/// - `grid`: spatial side of the image embeddings (64 for SAM 2).
///
/// # Errors
///
/// Fails if an input length disagrees with `transformer_dim` / `grid`, if
/// high-res features are required but missing, if `hyper_mlps_ir` does not
/// hold one graph per mask token, or if any compiled graph returns an error.
///
/// # Panics
///
/// NOTE(review): slicing and `copy_from_slice` assume every head returns
/// the expected width (e.g. `transformer_dim/8` values per hyper-MLP and at
/// least one IoU score). A mismatch panics rather than returning an error.
#[allow(clippy::too_many_arguments)]
pub fn mask_decoder_forward(
    w: &Sam2MaskDecoderWeights,
    upscale: &mut Sam2MaskUpscaleCompiled,
    hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
    hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
    iou_head_ir: Option<&mut MlpReluCompiled>,
    obj_score_head_ir: Option<&mut MlpReluCompiled>,
    obj_ptr_proj_ir: Option<&mut MlpReluCompiled>,
    tw_ir: Option<&mut TwoWayTransformerCompiled>,
    image_embeddings: &[f32],
    image_pe: &[f32],
    sparse_prompt_embeddings: &[f32],
    num_sparse_tokens: usize,
    dense_prompt_embeddings: &[f32],
    high_res_features: Option<(&[f32], &[f32])>,
    multimask_output: bool,
    grid: usize,
) -> Result<Sam2MaskDecoderOutput> {
    let e = w.transformer_dim;
    let nm = w.num_mask_tokens;
    let g = grid;
    // ── Validate input lengths up front so later slicing cannot go out of range ──
    ensure!(
        image_embeddings.len() == e * g * g,
        "image_embeddings len {} ≠ E·g·g ({e}·{g}·{g})",
        image_embeddings.len()
    );
    ensure!(
        image_pe.len() == e * g * g,
        "image_pe len {} ≠ E·g·g",
        image_pe.len()
    );
    ensure!(
        dense_prompt_embeddings.len() == e * g * g,
        "dense_prompt_embeddings len {} ≠ E·g·g",
        dense_prompt_embeddings.len()
    );
    ensure!(
        sparse_prompt_embeddings.len() == num_sparse_tokens * e,
        "sparse_prompt_embeddings len {} ≠ num_sparse·E ({num_sparse_tokens}·{e})",
        sparse_prompt_embeddings.len()
    );
    if w.use_high_res_features {
        let (s0, s1) = high_res_features.ok_or_else(|| {
            anyhow::anyhow!("use_high_res_features=true requires (feat_s0, feat_s1)")
        })?;
        ensure!(
            s0.len() == e * (4 * g) * (4 * g),
            "feat_s0 len {} ≠ E·4g·4g ({e}·{}·{})",
            s0.len(),
            4 * g,
            4 * g
        );
        ensure!(
            s1.len() == e * (2 * g) * (2 * g),
            "feat_s1 len {} ≠ E·2g·2g ({e}·{}·{})",
            s1.len(),
            2 * g,
            2 * g
        );
    }

    // ── Build tokens = cat(maybe obj_score, iou, mask, sparse) ──
    // `s` is the offset of the IoU token in the output sequence: 1 when an
    // object-score token is prepended, else 0.
    let s = if w.obj_score_token.is_some() { 1 } else { 0 };
    let n_out_tokens = s + 1 + nm;
    let q_n = n_out_tokens + num_sparse_tokens;
    let mut tokens = Vec::with_capacity(q_n * e);
    if let Some(obj) = &w.obj_score_token {
        tokens.extend_from_slice(obj);
    }
    tokens.extend_from_slice(&w.iou_token);
    tokens.extend_from_slice(&w.mask_tokens);
    tokens.extend_from_slice(sparse_prompt_embeddings);

    // ── src = image_embeddings + dense_prompt_embeddings ──
    let mut src = image_embeddings.to_vec();
    for i in 0..src.len() {
        src[i] += dense_prompt_embeddings[i];
    }
    let pos_src = image_pe.to_vec();

    // ── Run the two-way transformer ──
    // A masked compiled graph accepts any q_n up to its capacity; an
    // unmasked one needs q_n to match exactly. Anything else uses the host path.
    let k_n = g * g;
    let (hs, src_post) = if let Some(tw) = tw_ir {
        if tw.masked && q_n <= tw.max_q_n && tw.k_n == k_n {
            tw.run_nchw_masked(&tokens, q_n, &src, &pos_src, g)?
        } else if !tw.masked && q_n == tw.max_q_n && tw.k_n == k_n {
            tw.run_nchw(&tokens, &src, &pos_src, g)?
        } else {
            two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, g, g, q_n)
        }
    } else {
        two_way_transformer_forward(&w.transformer, &src, &pos_src, &tokens, 1, e, g, g, q_n)
    };

    // Object score comes from output token 0.
    // NOTE(review): the IR head is used whenever supplied, even if
    // `obj_score_token` is absent (s == 0), in which case `hs[..e]` is the
    // IoU token. Confirm callers only pass it with `pred_obj_scores`.
    let obj_score_logits_pre = if let Some(ir) = obj_score_head_ir {
        ir.run(&hs[..e], 1)?
    } else if let Some(head) = &w.obj_score_head {
        let token = &hs[..e];
        mlp_forward(head, token, 1)
    } else {
        // Reference returns a constant +10 logit when pred_obj_scores=false.
        vec![10.0]
    };

    let iou_token_out: Vec<f32> = hs[s * e..(s + 1) * e].to_vec();
    let mask_tokens_out = hs[(s + 1) * e..(s + 1 + nm) * e].to_vec();

    // ── Reshape src_post [1, g·g, E] → [1, E, g, g] NCHW ──
    let mut src_nchw = vec![0f32; e * g * g];
    for ss in 0..g * g {
        for c in 0..e {
            src_nchw[c * g * g + ss] = src_post[ss * e + c];
        }
    }

    // ── Upscaling via IR (optional high-res 1×1 fuse inside graph) ──
    let q8 = e / 8;
    let h2 = g * 4;
    let w2 = g * 4;
    // Without high-res features, empty slices are passed. Presumably the
    // graph ignores them when built without fusion — confirm in upscale_ir.
    let (feat_s0, feat_s1) = high_res_features.unwrap_or((&[] as &[f32], &[] as &[f32]));
    // NOTE(review): argument order is (src, s1, s0). This presumably matches
    // the graph inputs (stride-8 is fused after the first deconv); confirm
    // against `Sam2MaskUpscaleCompiled::run`.
    let up2 = upscale.run(&src_nchw, feat_s1, feat_s0, g)?;

    // ── Hypernetwork MLPs → [nm, q8] ──
    let mut hyper_in = vec![0f32; nm * q8];
    if let Some(mlps) = hyper_mlps_ir {
        ensure!(
            mlps.len() == nm,
            "hyper_mlps_ir len {} ≠ num_mask_tokens {}",
            mlps.len(),
            nm
        );
        for i in 0..nm {
            let token = &mask_tokens_out[i * e..(i + 1) * e];
            let h = mlps[i].run(token, 1)?;
            hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
        }
    } else {
        for i in 0..nm {
            let token = &mask_tokens_out[i * e..(i + 1) * e];
            let h = mlp_forward(&w.hyper_mlps[i], token, 1);
            hyper_in[i * q8..(i + 1) * q8].copy_from_slice(&h);
        }
    }
    // masks_all = hyper_in [nm, q8] × up2 [q8, h2·w2] → [nm, h2·w2].
    let spat = h2 * w2;
    let mut masks_all = vec![0f32; nm * spat];
    if let Some(hm) = hyper_matmul {
        hm.run(&hyper_in, &up2, &mut masks_all)?;
    } else {
        rlx_cpu::blas::sgemm_auto(&hyper_in, &up2, &mut masks_all, nm, q8, spat);
    }

    // ── IoU head ──
    let iou_pred_all = if let Some(head) = iou_head_ir {
        head.run(&iou_token_out, 1)?
    } else {
        mlp_forward(&w.iou_head, &iou_token_out, 1)
    };

    // ── Multimask selection (with optional dynamic stability fallback) ──
    // `ptr_indices` are indices into the nm mask tokens whose embeddings feed
    // the object-pointer projection below.
    let (masks, iou_pred, num_masks, ptr_indices): (Vec<f32>, Vec<f32>, usize, Vec<usize>) =
        if multimask_output {
            // [1:nm] = 3 masks for nm=4.
            let masks = masks_all[spat..].to_vec();
            let iou = iou_pred_all[1..].to_vec();
            let ptr = if w.use_multimask_token_for_obj_ptr {
                (1..nm).collect()
            } else {
                vec![0]
            };
            (masks, iou, nm - 1, ptr)
        } else if w.dynamic_multimask_via_stability {
            dynamic_multimask_via_stability(
                &masks_all,
                &iou_pred_all,
                nm,
                spat,
                w.dynamic_multimask_stability_delta,
                w.dynamic_multimask_stability_thresh,
            )
        } else {
            let masks = masks_all[..spat].to_vec();
            let iou = iou_pred_all[..1].to_vec();
            (masks, iou, 1, vec![0])
        };

    // Gather the selected mask-token rows into [num_ptr_tokens, E].
    let num_ptr_tokens = ptr_indices.len();
    let mut sam_tokens_out = Vec::with_capacity(num_ptr_tokens * e);
    for &pi in &ptr_indices {
        sam_tokens_out.extend_from_slice(&mask_tokens_out[pi * e..(pi + 1) * e]);
    }

    // Compiled projection only when it was built for this row count;
    // otherwise use the host MLP (None if the decoder has no projection).
    let object_pointer = if let Some(ir) = obj_ptr_proj_ir {
        if ir.compiled_rows() == num_ptr_tokens {
            Some(ir.run(&sam_tokens_out, num_ptr_tokens)?)
        } else {
            w.obj_ptr_proj
                .as_ref()
                .map(|proj| mlp_forward(proj, &sam_tokens_out, num_ptr_tokens))
        }
    } else {
        w.obj_ptr_proj
            .as_ref()
            .map(|proj| mlp_forward(proj, &sam_tokens_out, num_ptr_tokens))
    };

    Ok(Sam2MaskDecoderOutput {
        masks,
        iou_pred,
        num_masks,
        h_out: h2,
        w_out: w2,
        sam_tokens_out,
        num_ptr_tokens,
        object_score_logits: obj_score_logits_pre,
        object_pointer,
    })
}
632
633/// Reference's `_dynamic_multimask_via_stability`: pick between the
634/// single-mask token (index 0) and the best multimask token (1..nm)
635/// based on a stability score that compares mask area at two thresholds.
636fn dynamic_multimask_via_stability(
637    masks_all: &[f32],
638    iou_pred_all: &[f32],
639    _nm: usize,
640    spat: usize,
641    delta: f32,
642    thresh: f32,
643) -> (Vec<f32>, Vec<f32>, usize, Vec<usize>) {
644    // multimask logits [nm-1, spat], iou [nm-1]
645    let mm_masks = &masks_all[spat..];
646    let mm_iou = &iou_pred_all[1..];
647    // Best multimask by predicted IoU.
648    let best = mm_iou
649        .iter()
650        .enumerate()
651        .fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
652            if v > bv { (i, v) } else { (bi, bv) }
653        })
654        .0;
655
656    // Stability score of single-mask token (index 0).
657    let single_mask = &masks_all[..spat];
658    let stability = mask_stability_score(single_mask, delta);
659    if stability >= thresh {
660        // Single mask is stable enough; use it.
661        (single_mask.to_vec(), iou_pred_all[..1].to_vec(), 1, vec![0])
662    } else {
663        // Fall back to the best multimask token.
664        let masks = mm_masks[best * spat..(best + 1) * spat].to_vec();
665        let iou = vec![mm_iou[best]];
666        // Pointer index in the *original* nm tokens: best+1.
667        (masks, iou, 1, vec![best + 1])
668    }
669}
670
/// Stability of a mask: the fraction of "loosely on" pixels (logit above
/// `-delta`) that are also "firmly on" (logit above `+delta`).
///
/// Returns 1.0 when no pixel exceeds `-delta`. Mirrors
/// `_get_stability_scores` in the reference.
fn mask_stability_score(mask_logits: &[f32], delta: f32) -> f32 {
    let (area_firm, area_loose) =
        mask_logits
            .iter()
            .fold((0u32, 0u32), |(firm, loose), &logit| {
                (
                    firm + u32::from(logit > delta),
                    loose + u32::from(logit > -delta),
                )
            });
    if area_loose == 0 {
        return 1.0;
    }
    area_firm as f32 / area_loose as f32
}
686
687#[allow(dead_code)]
688fn _silence_add_inplace(x: &mut [f32], y: &[f32]) {
689    add_inplace(x, y);
690}