pub fn sam2_mask_decoder_forward(
w: &Sam2MaskDecoderWeights,
upscale: &mut Sam2MaskUpscaleCompiled,
hyper_matmul: Option<&mut MaskHyperMatmulCompiled>,
hyper_mlps_ir: Option<&mut [MlpReluCompiled]>,
iou_head_ir: Option<&mut MlpReluCompiled>,
obj_score_head_ir: Option<&mut MlpReluCompiled>,
obj_ptr_proj_ir: Option<&mut MlpReluCompiled>,
tw_ir: Option<&mut TwoWayTransformerCompiled>,
image_embeddings: &[f32],
image_pe: &[f32],
sparse_prompt_embeddings: &[f32],
num_sparse_tokens: usize,
dense_prompt_embeddings: &[f32],
high_res_features: Option<(&[f32], &[f32])>,
multimask_output: bool,
grid: usize,
) -> Result<Sam2MaskDecoderOutput, Error>Expand description
Run the SAM 2 mask decoder.
image_embeddings: NCHW tensor of shape [1, C, grid, grid], where C = transformer_dim.
image_pe: image positional encoding, NCHW tensor of shape [1, C, grid, grid], where C = transformer_dim.
sparse_prompt_embeddings: tensor of shape [num_sparse_tokens, transformer_dim].
dense_prompt_embeddings: [transformer_dim, grid, grid].
high_res_features: optional pair (feat_s0, feat_s1), where:
- feat_s0: stride-4 features of shape [transformer_dim, 4·grid, 4·grid];
- feat_s1: stride-8 features of shape [transformer_dim, 2·grid, 2·grid].
In the reference implementation, these features come from the FpnNeck.
grid: spatial side length of the (square) image embeddings; 64 for SAM 2.