Function mask_decoder_forward 

pub fn mask_decoder_forward(
    w: &MaskDecoderWeights,
    upscale: &mut SamMaskUpscaleCompiled,
    hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
    hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
    iou_head_ir: Option<&mut MlpReluCompiled>,
    tw_ir: Option<&mut TwoWayTransformerCompiled>,
    image_embeddings: &[f32],
    image_pe: &[f32],
    sparse_prompt_embeddings: &[f32],
    num_sparse_tokens: usize,
    dense_prompt_embeddings: &[f32],
    multimask_output: bool,
) -> Result<(Vec<f32>, Vec<f32>, usize, usize)>
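Runs a forward pass through the mask decoder. The first two elements of the returned tuple are masks and iou_pred, both as flat Vec<f32> buffers; the two trailing usize values are not described here.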

Input shapes:

  • image_embeddings: NCHW [1, C=256, hw, hw].
  • image_pe: NCHW [1, C=256, hw, hw].
  • sparse_prompt_embeddings: [1, num_sparse, E], where num_sparse is given by num_sparse_tokens and may be 0.
  • dense_prompt_embeddings: [1, E, hw, hw].
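
The shapes above imply a flat length for each input slice. The following is a minimal sketch of a pre-call length check, assuming the slices are contiguous, row-major, and batch size 1. `ExpectedLens` and `expected_input_lens` are hypothetical helpers for illustration, not part of this crate.

```rust
/// Expected flat lengths of the four input slices.
/// Hypothetical helper type; assumes contiguous row-major storage, batch = 1.
#[derive(Debug, PartialEq, Eq)]
pub struct ExpectedLens {
    pub image_embeddings: usize,
    pub image_pe: usize,
    pub sparse_prompt_embeddings: usize,
    pub dense_prompt_embeddings: usize,
}

/// Computes the element count of each input from the documented shapes:
/// [1, C, hw, hw], [1, C, hw, hw], [1, num_sparse, E], [1, E, hw, hw].
pub fn expected_input_lens(
    c: usize,
    e: usize,
    hw: usize,
    num_sparse_tokens: usize,
) -> ExpectedLens {
    ExpectedLens {
        image_embeddings: c * hw * hw,
        image_pe: c * hw * hw,
        // Zero sparse tokens is allowed, which yields an empty slice.
        sparse_prompt_embeddings: num_sparse_tokens * e,
        dense_prompt_embeddings: e * hw * hw,
    }
}
```

A caller could compare these values against `image_embeddings.len()` and the other slice lengths before invoking the decoder, so that a shape mismatch is reported early.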

multimask_output: if true, return mask channels 1..4 (3 candidates); otherwise return mask channel 0 only (the single “best” output).
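
The channel selection described above can be sketched as a slice over a flat buffer. This is an illustration only: it assumes the decoder produces 4 candidate channels laid out as [1, 4, h, w] in row-major order, and `selected_mask_range` and `select_masks` are hypothetical names, not functions from this crate.

```rust
use std::ops::Range;

/// Channel range kept for a given `multimask_output` flag:
/// channels 1..4 (three candidates) or channel 0 alone.
pub fn selected_mask_range(multimask_output: bool) -> Range<usize> {
    if multimask_output {
        1..4
    } else {
        0..1
    }
}

/// Copies the selected channels out of a flat [1, 4, h, w] buffer.
/// Assumption: row-major layout, so each channel is one contiguous
/// plane of h * w values.
pub fn select_masks(all_masks: &[f32], h: usize, w: usize, multimask_output: bool) -> Vec<f32> {
    let plane = h * w;
    let r = selected_mask_range(multimask_output);
    all_masks[r.start * plane..r.end * plane].to_vec()
}
```

Because the selected channels are adjacent in memory under this layout, the selection is a single contiguous copy.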

Output shapes:

  • masks: [1, num_masks, 4·hw, 4·hw] (num_masks = 3 if multimask_output else 1).
  • iou_pred: [1, num_masks].
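
Given the output shapes above, a caller may want to pick one mask plane out of the flat `masks` buffer. The sketch below selects the candidate with the highest predicted IoU, which is a common way to consume such outputs but is an assumption here, not something this function does. `best_mask` is a hypothetical helper, and the layout is assumed to be row-major with batch size 1.

```rust
/// Returns the index and the flat plane of the mask with the highest
/// predicted IoU, or `None` if the buffers do not match the documented
/// shapes masks: [1, num_masks, 4*hw, 4*hw], iou_pred: [1, num_masks].
pub fn best_mask<'a>(
    masks: &'a [f32],
    iou_pred: &[f32],
    hw: usize,
) -> Option<(usize, &'a [f32])> {
    let side = 4 * hw; // masks are upscaled 4x relative to the embedding grid
    let plane = side * side;
    if iou_pred.is_empty() || masks.len() != iou_pred.len() * plane {
        return None;
    }
    // Linear scan for the highest score; ties keep the earliest index.
    let mut best = 0;
    for (i, &score) in iou_pred.iter().enumerate() {
        if score > iou_pred[best] {
            best = i;
        }
    }
    Some((best, &masks[best * plane..(best + 1) * plane]))
}
```

With multimask_output set to false there is a single candidate, so the helper simply returns index 0 and the only plane.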