pub fn mask_decoder_forward(
    w: &MaskDecoderWeights,
    upscale: &mut SamMaskUpscaleCompiled,
    hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
    hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
    iou_head_ir: Option<&mut MlpReluCompiled>,
    tw_ir: Option<&mut TwoWayTransformerCompiled>,
    image_embeddings: &[f32],
    image_pe: &[f32],
    sparse_prompt_embeddings: &[f32],
    num_sparse_tokens: usize,
    dense_prompt_embeddings: &[f32],
    multimask_output: bool,
) -> Result<(Vec<f32>, Vec<f32>, usize, usize)>
Runs the mask decoder forward pass. The returned tuple starts with (masks, iou_pred); the meaning of the two trailing usize values is not documented.
image_embeddings: NCHW [1, C=256, hw, hw].
image_pe: NCHW [1, C=256, hw, hw].
sparse_prompt_embeddings: [1, num_sparse_tokens, E]; num_sparse_tokens may be 0.
dense_prompt_embeddings: [1, E, hw, hw].
multimask_output: if true, return mask channels 1..4 along the num_masks axis (masks[:, 1:4], 3 candidates);
otherwise return channel 0 only (masks[:, 0:1], the single “best” output).
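The channel selection described for multimask_output can be sketched on flat buffers. This is a minimal sketch, not this crate's code: the helper name select_masks is hypothetical, and it assumes the decoder first produces at least 4 mask channels in a flat NCHW buffer [1, total, h, w] plus one IoU score per channel. Because the batch size is 1, each channel is a contiguous run of h·w values, so a channel range is a single slice.

```rust
/// Hypothetical helper (not part of this crate's API): pick the mask
/// channels selected by `multimask_output` from a flat NCHW buffer of
/// shape [1, total, h, w], plus the matching IoU scores [1, total].
fn select_masks(
    all_masks: &[f32],
    all_iou: &[f32],
    total: usize,
    h: usize,
    w: usize,
    multimask_output: bool,
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(all_masks.len(), total * h * w, "masks must be [1, total, h, w]");
    assert_eq!(all_iou.len(), total, "iou_pred must be [1, total]");
    // Channels 1..4 for multimask output, channel 0 otherwise.
    let (start, end) = if multimask_output { (1, 4) } else { (0, 1) };
    assert!(end <= total, "need at least {} mask channels", end);
    let plane = h * w;
    // With batch size 1, channel c occupies [c * h * w, (c + 1) * h * w),
    // so a contiguous channel range is one contiguous slice.
    let masks = all_masks[start * plane..end * plane].to_vec();
    let iou = all_iou[start..end].to_vec();
    (masks, iou)
}

fn main() {
    // 4 channels of 2x2 pixels; every pixel of channel c holds the value c.
    let (total, h, w) = (4, 2, 2);
    let all_masks: Vec<f32> = (0..total)
        .flat_map(|c| std::iter::repeat(c as f32).take(h * w))
        .collect();
    let all_iou: Vec<f32> = vec![0.1, 0.2, 0.3, 0.4];

    let (m, iou) = select_masks(&all_masks, &all_iou, total, h, w, true);
    assert_eq!(m.len(), 3 * h * w);
    assert_eq!(m[0], 1.0);
    assert_eq!(iou, vec![0.2, 0.3, 0.4]);

    let (m, iou) = select_masks(&all_masks, &all_iou, total, h, w, false);
    assert_eq!(m, vec![0.0; h * w]);
    assert_eq!(iou, vec![0.1]);
    println!("ok");
}
```

Slicing the IoU scores with the same channel range keeps iou_pred aligned with masks, which matches the documented iou_pred shape [1, num_masks].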
Output shapes:
- masks: [1, num_masks, 4·hw, 4·hw], where num_masks = 3 if multimask_output else 1.
- iou_pred: [1, num_masks].
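The documented output shapes translate directly into expected buffer lengths, which is handy for checking the returned Vec<f32> values. A minimal sketch: expected_output_lens is a hypothetical helper, and hw = 64 in the usage below is only an assumed example grid size.

```rust
/// Hypothetical helper: expected element counts of the two output buffers,
/// following the documented shapes
///   masks:    [1, num_masks, 4*hw, 4*hw]
///   iou_pred: [1, num_masks]
fn expected_output_lens(hw: usize, multimask_output: bool) -> (usize, usize) {
    let num_masks = if multimask_output { 3 } else { 1 };
    // Each output mask is 4x the embedding grid along both spatial axes.
    let side = 4 * hw;
    (num_masks * side * side, num_masks)
}

fn main() {
    // Example only: assume a 64x64 embedding grid, giving 256x256 masks.
    assert_eq!(expected_output_lens(64, true), (3 * 256 * 256, 3));
    assert_eq!(expected_output_lens(64, false), (256 * 256, 1));
    println!("{:?}", expected_output_lens(64, true)); // prints "(196608, 3)"
}
```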