
sam2_mask_decoder_forward

Function sam2_mask_decoder_forward 

pub fn sam2_mask_decoder_forward(
    w: &Sam2MaskDecoderWeights,
    upscale: &mut Sam2MaskUpscaleCompiled,
    hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
    hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
    iou_head_ir: Option<&mut MlpReluCompiled>,
    obj_score_head_ir: Option<&mut MlpReluCompiled>,
    obj_ptr_proj_ir: Option<&mut MlpReluCompiled>,
    tw_ir: Option<&mut TwoWayTransformerCompiled>,
    image_embeddings: &[f32],
    image_pe: &[f32],
    sparse_prompt_embeddings: &[f32],
    num_sparse_tokens: usize,
    dense_prompt_embeddings: &[f32],
    high_res_features: Option<(&[f32], &[f32])>,
    multimask_output: bool,
    grid: usize,
) -> Result<Sam2MaskDecoderOutput, Error>

Run the SAM 2 mask decoder.

image_embeddings: NCHW [1, C=transformer_dim, grid, grid].
image_pe: NCHW [1, C=transformer_dim, grid, grid].
sparse_prompt_embeddings: [num_sparse, transformer_dim].
dense_prompt_embeddings: [transformer_dim, grid, grid].
high_res_features: optional (feat_s0, feat_s1) where:

  • feat_s0: stride-4 features [transformer_dim, 4·grid, 4·grid]
  • feat_s1: stride-8 features [transformer_dim, 2·grid, 2·grid]

The reference implementation passes these from the FpnNeck.

grid: spatial side length of the image embeddings (64 for SAM 2).
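Because every tensor argument is a flat `&[f32]`, a caller may want to check slice lengths against the documented shapes before invoking the decoder. The following is a minimal sketch of such a check, not part of this crate: `ExpectedLens`, `expected_lens`, and `check_len` are hypothetical names, and the sketch assumes contiguous storage with a batch size of 1. The value of `transformer_dim` is taken as a plain parameter here, since this page does not show how it is read from `Sam2MaskDecoderWeights`.

```rust
/// Element counts implied by the documented shapes.
/// Hypothetical helper type; it does not exist in the crate.
#[derive(Debug, PartialEq, Eq)]
pub struct ExpectedLens {
    pub image_embeddings: usize,
    pub image_pe: usize,
    pub sparse_prompt_embeddings: usize,
    pub dense_prompt_embeddings: usize,
    pub feat_s0: usize,
    pub feat_s1: usize,
}

/// Compute the expected slice lengths, assuming batch size 1 and
/// contiguous storage (both are assumptions of this sketch).
pub fn expected_lens(
    transformer_dim: usize,
    grid: usize,
    num_sparse_tokens: usize,
) -> ExpectedLens {
    let plane = grid * grid;
    ExpectedLens {
        // [1, transformer_dim, grid, grid]
        image_embeddings: transformer_dim * plane,
        image_pe: transformer_dim * plane,
        // [num_sparse, transformer_dim]
        sparse_prompt_embeddings: num_sparse_tokens * transformer_dim,
        // [transformer_dim, grid, grid]
        dense_prompt_embeddings: transformer_dim * plane,
        // stride-4 features: [transformer_dim, 4*grid, 4*grid]
        feat_s0: transformer_dim * (4 * grid) * (4 * grid),
        // stride-8 features: [transformer_dim, 2*grid, 2*grid]
        feat_s1: transformer_dim * (2 * grid) * (2 * grid),
    }
}

/// Return an error message if `data` does not have the expected length.
pub fn check_len(name: &str, data: &[f32], expected: usize) -> Result<(), String> {
    if data.len() == expected {
        Ok(())
    } else {
        Err(format!(
            "{name}: expected {expected} elements, got {}",
            data.len()
        ))
    }
}
```

With `grid = 64` and an assumed `transformer_dim` of 256, `expected_lens(256, 64, 2)` gives 1,048,576 elements for `image_embeddings` and 16,777,216 for `feat_s0`. The stride-4 feature map therefore dominates the memory footprint of the inputs.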