Module memory_encoder

SAM 2 memory encoder — host-side.

Mirrors sam2/modeling/memory_encoder.py exactly:

  MemoryEncoder(pix_feat, masks):
    masks = sigmoid(masks) if not skip_mask_sigmoid
    masks = MaskDownSampler(masks)        # 1×1024×1024 → 256×64×64
    pix_feat = pix_feat_proj(pix_feat)    # 1×1 conv 256→256
    x = pix_feat + masks
    x = Fuser(x)                          # 2 × CXBlock
    x = out_proj(x)                       # 1×1 conv 256→out_dim (64)
    pos = PositionEmbeddingSine(x)        # sinusoidal 2-D PE
    return (x, pos)

MaskDownSampler is a stack of log_stride(total_stride) blocks, i.e. the logarithm of total_stride to base stride (4 blocks for the default stride=2, total_stride=16). Each block is Conv2d(k,s,p) → LayerNorm2d → GELU and multiplies the channel count by stride², giving 1 → 4 → 16 → 64 → 256 for the defaults. A final 1×1 conv projects to embed_dim=in_dim=256.
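
The schedule described above can be traced with a small shape calculator. This is a minimal sketch, not this module's API: `LevelShape` and `mask_downsampler_schedule` are hypothetical names, and the kernel and padding values used in the usage note (k=3, p=1) are an assumption about the SAM 2 configuration, since the text only states stride=2 and total_stride=16.

```rust
/// Shape of one downsampling level (hypothetical helper type, not the
/// crate's `DownSampleLevel`).
#[derive(Debug, Clone, PartialEq)]
pub struct LevelShape {
    pub in_ch: usize,
    pub out_ch: usize,
    pub height: usize,
    pub width: usize,
}

/// Output length of a convolution along one axis:
/// floor((n + 2p - k) / s) + 1.
fn conv_out(n: usize, k: usize, s: usize, p: usize) -> usize {
    assert!(n + 2 * p >= k, "kernel larger than padded input");
    (n + 2 * p - k) / s + 1
}

/// Lists the strided-conv levels needed to reach `total_stride`,
/// multiplying the channel count by stride^2 at every level.
/// The final 1x1 projection is not included because it changes no shape
/// when embed_dim equals the last level's channel count.
pub fn mask_downsampler_schedule(
    height: usize,
    width: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    total_stride: usize,
) -> Vec<LevelShape> {
    assert!(stride >= 2, "stride must be at least 2");
    let mut levels = Vec::new();
    let (mut ch, mut h, mut w, mut reached) = (1usize, height, width, 1usize);
    while reached < total_stride {
        let out_ch = ch * stride * stride;
        h = conv_out(h, kernel, stride, padding);
        w = conv_out(w, kernel, stride, padding);
        levels.push(LevelShape { in_ch: ch, out_ch, height: h, width: w });
        ch = out_ch;
        reached *= stride;
    }
    levels
}
```

Under those assumed values, `mask_downsampler_schedule(1024, 1024, 3, 2, 1, 16)` yields four levels whose channels run 1 → 4 → 16 → 64 → 256 and whose spatial size halves from 1024 down to 64, matching the 1×1024×1024 → 256×64×64 comment in the pipeline.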

Fuser is a ConvNeXt-style stack of CXBlocks (2 in the pipeline above). Each block applies depthwise Conv k=7 → LN → pointwise Linear (4× expansion) → GELU → pointwise Linear → optional per-channel gamma (LayerScale), then adds the result to the block input (residual).
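
To make the order of operations inside one block concrete, here is a naive reference sketch on a single CHW tensor stored as a flat `Vec<f32>`. Everything in it is illustrative: `CxBlockSketch` and `cx_block_forward` are hypothetical names (not `Sam2CXBlockWeights` or this module's functions), the weight layouts are assumed, padding 3 and LayerNorm eps 1e-6 are assumptions, and GELU uses the tanh approximation rather than the exact erf form.

```rust
/// Hypothetical weights for one block; layouts are assumptions.
pub struct CxBlockSketch {
    pub dim: usize,
    pub dw_weight: Vec<f32>,     // [dim, 7, 7]
    pub dw_bias: Vec<f32>,       // [dim]
    pub ln_weight: Vec<f32>,     // [dim]
    pub ln_bias: Vec<f32>,       // [dim]
    pub pw1_weight: Vec<f32>,    // [4*dim, dim], row-major
    pub pw1_bias: Vec<f32>,      // [4*dim]
    pub pw2_weight: Vec<f32>,    // [dim, 4*dim], row-major
    pub pw2_bias: Vec<f32>,      // [dim]
    pub gamma: Option<Vec<f32>>, // [dim], LayerScale; None disables it
}

/// GELU, tanh approximation.
fn gelu(x: f32) -> f32 {
    let c = (2.0f32 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

/// Forward pass of one block on a [dim, h, wd] tensor in CHW order.
pub fn cx_block_forward(w: &CxBlockSketch, x: &[f32], h: usize, wd: usize) -> Vec<f32> {
    let (c, k, pad) = (w.dim, 7usize, 3isize);
    assert_eq!(x.len(), c * h * wd);

    // Step 1: depthwise 7x7 convolution with zero padding (one filter per channel).
    let mut dw = vec![0.0f32; x.len()];
    for ch in 0..c {
        for y in 0..h {
            for xx in 0..wd {
                let mut acc = w.dw_bias[ch];
                for ky in 0..k {
                    for kx in 0..k {
                        let sy = y as isize + ky as isize - pad;
                        let sx = xx as isize + kx as isize - pad;
                        if sy >= 0 && sy < h as isize && sx >= 0 && sx < wd as isize {
                            acc += w.dw_weight[(ch * k + ky) * k + kx]
                                * x[(ch * h + sy as usize) * wd + sx as usize];
                        }
                    }
                }
                dw[(ch * h + y) * wd + xx] = acc;
            }
        }
    }

    // Steps 2-6 act independently on the channel vector of each pixel.
    let hidden = 4 * c;
    let mut out = x.to_vec(); // start from the residual branch
    for p in 0..h * wd {
        let v: Vec<f32> = (0..c).map(|ch| dw[ch * h * wd + p]).collect();
        // Step 2: LayerNorm over channels.
        let mean = v.iter().sum::<f32>() / c as f32;
        let var = v.iter().map(|a| (a - mean).powi(2)).sum::<f32>() / c as f32;
        let inv = 1.0 / (var + 1e-6).sqrt();
        let n: Vec<f32> = (0..c)
            .map(|i| (v[i] - mean) * inv * w.ln_weight[i] + w.ln_bias[i])
            .collect();
        // Steps 3-4: pointwise Linear with 4x expansion, then GELU.
        let hid: Vec<f32> = (0..hidden)
            .map(|j| {
                let s: f32 = (0..c).map(|i| w.pw1_weight[j * c + i] * n[i]).sum();
                gelu(s + w.pw1_bias[j])
            })
            .collect();
        for i in 0..c {
            // Step 5: pointwise Linear back to `dim` channels.
            let s: f32 = (0..hidden).map(|j| w.pw2_weight[i * hidden + j] * hid[j]).sum();
            let mut y = s + w.pw2_bias[i];
            // Step 6: optional LayerScale, then the residual add.
            if let Some(g) = &w.gamma {
                y *= g[i];
            }
            out[i * h * wd + p] += y;
        }
    }
    out
}
```

One property worth noting: with `gamma` set to all zeros the block reduces to the identity, because only the residual branch survives. The nested loops are written for readability, not speed.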

Structs§

DownSampleLevel
Sam2CXBlockWeights
Sam2FuserWeights
Sam2MaskDownSamplerWeights
Sam2MemoryEncoderOutput
Sam2MemoryEncoderWeights

Functions§

compile_memory_encoder_ir
Compile the memory-encoder IR subgraphs (mask downsampler, pix_feat 1×1 projection, fuser, optional 1×1 output projection).
compile_memory_mask_ir
Backward-compatibility alias that compiles only the mask downsampler.
extract_memory_encoder_weights
memory_encoder_forward
Run the SAM 2 memory encoder.