
Crate rlx_sam



SAM v1 — Meta’s Segment Anything image-segmentation model.

Phasing

Phase 1 (this commit) lands the image encoder end-to-end:

  • Host-side preprocessing (resize the longest side to 1024, ImageNet pixel normalization, zero-pad to 1024×1024, patch embedding via Conv2d-as-matmul).
  • IR graph for the 12 encoder blocks: windowed and global attention, decomposed relative position embeddings, plain GELU-tanh MLPs, and a pre-norm residual structure.
  • IR neck (Conv2d 1×1 → LN2d → Conv2d 3×3 → LN2d → [256, 64, 64]).
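The preprocessing bullet above can be sketched on the host as follows. This is a minimal illustration, not the crate's preprocess_image: it assumes Meta's reference values (1024-pixel input, 16-pixel patches, ImageNet mean/std on a 0..=255 scale) and Meta's ResizeLongestSide rounding. The constants and both helper functions are hypothetical names, and the crate's own SAM_PIXEL_MEAN / SAM_PIXEL_STD may be stored on a different scale.

```rust
/// Encoder input side (1024 in Meta's SAM v1 reference).
const IMG_SIZE: usize = 1024;
/// ViT-B patch size (16 in Meta's reference).
const PATCH: usize = 16;

/// Meta's reference ImageNet mean/std on a 0..=255 pixel scale. Assumed here;
/// the crate's SAM_PIXEL_MEAN / SAM_PIXEL_STD may be stored differently.
const PIXEL_MEAN: [f32; 3] = [123.675, 116.28, 103.53];
const PIXEL_STD: [f32; 3] = [58.395, 57.12, 57.375];

/// Resized (h, w) so the longest side equals `target`, rounded to nearest,
/// like Meta's ResizeLongestSide.get_preprocess_shape.
fn resize_longest_side_shape(h: usize, w: usize, target: usize) -> (usize, usize) {
    let scale = target as f64 / h.max(w) as f64;
    ((h as f64 * scale + 0.5) as usize, (w as f64 * scale + 0.5) as usize)
}

/// Normalize an already-resized CHW image and zero-pad it on the bottom and
/// right to IMG_SIZE x IMG_SIZE. Padding happens after normalization, so the
/// padded region is exactly 0.0.
fn normalize_and_pad(chw: &[f32], h: usize, w: usize) -> Vec<f32> {
    assert!(h <= IMG_SIZE && w <= IMG_SIZE, "resize before padding");
    assert_eq!(chw.len(), 3 * h * w);
    let plane = IMG_SIZE * IMG_SIZE;
    let mut out = vec![0.0f32; 3 * plane];
    for c in 0..3 {
        for y in 0..h {
            for x in 0..w {
                let v = chw[c * h * w + y * w + x];
                out[c * plane + y * IMG_SIZE + x] = (v - PIXEL_MEAN[c]) / PIXEL_STD[c];
            }
        }
    }
    out
}

fn main() {
    // 480x640 input: scale = 1024 / 640 = 1.6, so the resized shape is 768x1024.
    let (h, w) = resize_longest_side_shape(480, 640, IMG_SIZE);
    println!("resized to {h}x{w}");

    // Patch grid: 1024 / 16 = 64 per side, i.e. 64 * 64 = 4096 patch tokens.
    let grid = IMG_SIZE / PATCH;
    println!("patch grid {grid}x{grid} = {} tokens", grid * grid);

    // Tiny mid-gray image just to exercise the normalize + pad path.
    let img = vec![128.0f32; 3 * 2 * 2];
    let padded = normalize_and_pad(&img, 2, 2);
    println!("padded length = {}", padded.len());
}
```

Normalizing before padding is what makes the padded border contribute exact zeros to the patch-embedding matmul.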

Phase 1 status: numerical parity with candle’s ImageEncoderViT::forward() on the real sam_vit_b_01ec64.safetensors weights, with max |Δ| = 7.15e-6 on the 1×256×64×64 image embeddings (full 12-layer ViT-B at 1024×1024 input). The Phase 1 bisection environment variables remain in tests/sam_parity.rs for future debugging:

  • RLX_SAM_DEBUG_DEPTH=N: run only the first N encoder blocks
  • RLX_SAM_DEBUG_NO_RELPOS=1: disable the decomposed relative position embeddings
  • RLX_SAM_DEBUG_FORCE_GLOBAL=1: force every block to use global attention
  • RLX_SAM_DEBUG_ZERO_RELH=1 / RLX_SAM_DEBUG_ZERO_RELW=1: zero the data of a single rel_pos axis (the matmul and add still execute)
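A parity harness built around these knobs might look like the sketch below: it parses the RLX_SAM_DEBUG_* variables and computes the max |Δ| metric quoted above. The DebugKnobs struct, its "value must be exactly 1" rule, max_abs_diff, and the 1e-4 tolerance are all hypothetical; the real tests/sam_parity.rs may be organized differently.

```rust
use std::env;

/// Hypothetical mirror of the RLX_SAM_DEBUG_* bisection knobs; the real
/// tests/sam_parity.rs may parse them differently.
#[derive(Debug, Default, PartialEq)]
struct DebugKnobs {
    depth: Option<usize>, // RLX_SAM_DEBUG_DEPTH=N
    no_relpos: bool,      // RLX_SAM_DEBUG_NO_RELPOS=1
    force_global: bool,   // RLX_SAM_DEBUG_FORCE_GLOBAL=1
    zero_relh: bool,      // RLX_SAM_DEBUG_ZERO_RELH=1
    zero_relw: bool,      // RLX_SAM_DEBUG_ZERO_RELW=1
}

impl DebugKnobs {
    /// Parse from any key lookup so tests need not mutate the process environment.
    /// Boolean knobs are on only when the value is exactly "1" (an assumption).
    fn from_lookup(get: impl Fn(&str) -> Option<String>) -> Self {
        let flag = |key: &str| get(key).as_deref() == Some("1");
        DebugKnobs {
            depth: get("RLX_SAM_DEBUG_DEPTH").and_then(|v| v.parse().ok()),
            no_relpos: flag("RLX_SAM_DEBUG_NO_RELPOS"),
            force_global: flag("RLX_SAM_DEBUG_FORCE_GLOBAL"),
            zero_relh: flag("RLX_SAM_DEBUG_ZERO_RELH"),
            zero_relw: flag("RLX_SAM_DEBUG_ZERO_RELW"),
        }
    }

    /// Read the knobs from the real process environment.
    fn from_env() -> Self {
        Self::from_lookup(|key| env::var(key).ok())
    }
}

/// Largest elementwise |a - b|, the metric behind the "max |Δ|" figure.
fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "tensor length mismatch");
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
}

fn main() {
    let knobs = DebugKnobs::from_env();
    println!("debug knobs: {knobs:?}");

    // Toy stand-ins for the flattened 1x256x64x64 embeddings.
    let reference = [0.5f32, -1.0, 2.0];
    let ours = [0.5f32, -1.0, 2.000_007];
    let delta = max_abs_diff(&reference, &ours);
    // Hypothetical tolerance; choose one that fits the model and precision.
    assert!(delta < 1e-4, "parity failure: max |Δ| = {delta:e}");
    println!("max |Δ| = {delta:e}");
}
```

Taking the lookup as a closure keeps the parsing testable without setting process-wide environment variables.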

Phase 2 (next commit) lands the prompt encoder + mask decoder:

  • Random Fourier positional encoding, point/box/mask embeddings.
  • Two-way transformer between prompt tokens and image embeddings.
  • ConvTranspose2d upscaling (IR) plus hypernetwork MLPs for the mask and IoU outputs.
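The random Fourier positional encoding in the first bullet can be sketched per point as below, following the structure of Meta's PositionEmbeddingRandom: map normalized coordinates from [0, 1] to [-1, 1], project with a 2×F Gaussian matrix, scale by 2π, and concatenate sines and cosines. In Meta's reference the Gaussian matrix is a stored buffer with F = 128 (256-dim output), so it comes from the checkpoint rather than being resampled. The function names here are hypothetical and the toy matrix is not real weight data.

```rust
use std::f32::consts::PI;

/// Random Fourier positional encoding of one point, in the style of Meta's
/// PositionEmbeddingRandom: map (x, y) from [0, 1] to [-1, 1], project with a
/// 2xF Gaussian matrix, scale by 2*pi, and concatenate [sin, cos] into 2F values.
fn fourier_pe(x: f32, y: f32, gauss: &[Vec<f32>; 2]) -> Vec<f32> {
    let f = gauss[0].len();
    assert_eq!(gauss[1].len(), f, "both rows of the Gaussian matrix need F columns");
    let (cx, cy) = (2.0 * x - 1.0, 2.0 * y - 1.0);
    let proj: Vec<f32> = (0..f)
        .map(|j| 2.0 * PI * (cx * gauss[0][j] + cy * gauss[1][j]))
        .collect();
    let mut out = Vec::with_capacity(2 * f);
    out.extend(proj.iter().map(|p| p.sin()));
    out.extend(proj.iter().map(|p| p.cos()));
    out
}

/// Normalize a pixel coordinate to [0, 1], shifting to the pixel centre (+0.5)
/// as Meta's prompt encoder does for point prompts.
fn normalize_point(px: f32, py: f32, width: f32, height: f32) -> (f32, f32) {
    ((px + 0.5) / width, (py + 0.5) / height)
}

fn main() {
    // Fixed toy matrix with F = 2; SAM v1 uses F = 128, giving 256-dim embeddings.
    let gauss = [vec![0.25f32, -0.5], vec![0.1f32, 0.3]];
    let (x, y) = normalize_point(511.5, 255.5, 1024.0, 1024.0);
    let emb = fourier_pe(x, y, &gauss);
    assert_eq!(emb.len(), 4);
    println!("{emb:?}");
}
```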

Weight keys follow the Meta / candle naming convention exactly, so the lmz/candle-sam safetensors checkpoints load without remapping.
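To make the convention concrete, the sketch below lists the per-block image-encoder keys as they appear in Meta's segment-anything module layout (norm1, attn.qkv, attn.proj, attn.rel_pos_h/w, norm2, mlp.lin1/lin2 under image_encoder.blocks.{i}) and checks a checkpoint's key set for gaps. Both helpers are hypothetical, not part of rlx_sam, and the key set in main is a stand-in rather than a real checkpoint read.

```rust
use std::collections::HashSet;

/// Weight keys for image-encoder block `i`, following the module names in Meta's
/// segment-anything reference (which candle mirrors). Hypothetical helper.
fn encoder_block_keys(i: usize) -> Vec<String> {
    const LEAVES: [&str; 14] = [
        "norm1.weight", "norm1.bias",
        "attn.qkv.weight", "attn.qkv.bias",
        "attn.proj.weight", "attn.proj.bias",
        "attn.rel_pos_h", "attn.rel_pos_w",
        "norm2.weight", "norm2.bias",
        "mlp.lin1.weight", "mlp.lin1.bias",
        "mlp.lin2.weight", "mlp.lin2.bias",
    ];
    LEAVES.iter().map(|leaf| format!("image_encoder.blocks.{i}.{leaf}")).collect()
}

/// Required keys absent from a checkpoint's key set. "No remapping" means this
/// list is empty when the checkpoint is read with the reference names as-is.
fn missing_keys(available: &HashSet<String>, required: &[String]) -> Vec<String> {
    required.iter().filter(|k| !available.contains(*k)).cloned().collect()
}

fn main() {
    let required = encoder_block_keys(0);
    // Stand-in for a checkpoint that holds every key for block 0.
    let checkpoint: HashSet<String> = required.iter().cloned().collect();
    assert!(missing_keys(&checkpoint, &required).is_empty());
    for key in &required {
        println!("{key}");
    }
}
```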

Re-exports

pub use config::EncoderKind;
pub use config::SAM_EMBED_HW;
pub use config::SAM_IMG_SIZE;
pub use config::SAM_PATCH_SIZE;
pub use config::SAM_PIXEL_MEAN;
pub use config::SAM_PIXEL_STD;
pub use config::SAM_PROMPT_EMBED_DIM;
pub use config::SamConfig;
pub use config::SamDecoderConfig;
pub use config::SamEncoderConfig;
pub use flow::SamEncoderBuilt;
pub use flow::SamEncoderFlow;
pub use flow::build_sam_encoder_built;
pub use image_encoder::NeckWeights;
pub use image_encoder::apply_neck_host;
pub use image_encoder::build_sam_encoder_graph;
pub use image_encoder::build_sam_encoder_hir;
pub use mask_decoder::MaskDecoderWeights;
pub use mask_decoder::mask_decoder_forward;
pub use preprocess::SamPreprocessWeights;
pub use preprocess::assemble_patch_tokens;
pub use preprocess::preprocess_image;
pub use profile::SAM_PROFILE_FILE;
pub use profile::sam_profile_default;
pub use profile::sam_profile_near_weights;
pub use profile::sam2_profile_default;
pub use profile::sam2_profile_near_weights;
pub use profile::sam3_profile_default;
pub use profile::sam3_profile_near_weights;
pub use prompt_encoder::PromptEncoderOutput;
pub use prompt_encoder::PromptEncoderWeights;
pub use prompt_encoder::prompt_encoder_forward;
pub use sam::MaskPrediction;
pub use sam::SAM_MASK_IN_CHANS;
pub use sam::Sam;
pub use sam::sam_vit_b_config;
pub use transformer::TwoWayTransformerWeights;
pub use transformer::attention_forward;
pub use transformer::two_way_transformer_forward;

Modules

cli
config
SAM v1 model configuration. Mirrors Meta’s segment-anything Python reference and candle’s segment_anything module.
flow
Tier-0 SAM v1 image encoder flow.
image_encoder
SAM v1 ViT image encoder HIR builder.
mask_decoder
SAM v1 mask decoder — transformer host-side; upscaling via IR graph.
mlp_ir
Compile SAM v1 mask-decoder ReLU MLP heads to IR.
preprocess
SAM v1 host-side preprocessing.
profile
Tier-1 compile profiles for SAM v1 / SAM2 / SAM3 loaders.
prompt_encoder
SAM v1 prompt encoder — Fourier/point embeddings host-side; mask downscaling stack compiled via super::prompt_mask_ir.
prompt_mask_ir
SAM v1 prompt-encoder mask downscale (IR).
sam
SAM v1 top-level orchestrator — ties the IR-graph image encoder together with the host-side prompt encoder + mask decoder.
transformer
SAM v1 two-way transformer — host-side.
transformer_ir
Compile SAM v1 two-way transformer to IR.
upscale_ir
SAM v1 mask-decoder upscaling subgraph (ConvTranspose2d + LN2d + GELU).
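The shape arithmetic behind the IR convolution stages (neck, prompt-mask downscale, mask-decoder upscale) follows from the standard Conv2d and ConvTranspose2d output-size formulas. The sketch below assumes Meta's reference kernel/stride choices (3×3 padding-1 neck conv, two 2×2 stride-2 convs for the mask downscale, two 2×2 stride-2 transposed convs for the upscale); whether prompt_mask_ir and upscale_ir use exactly these values is an assumption, and the helper names are hypothetical.

```rust
/// Spatial output size of a Conv2d with dilation 1.
fn conv2d_out(input: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (input + 2 * pad - kernel) / stride + 1
}

/// Spatial output size of a ConvTranspose2d with dilation 1 and output_padding 0.
fn conv_transpose2d_out(input: usize, kernel: usize, stride: usize, pad: usize) -> usize {
    (input - 1) * stride + kernel - 2 * pad
}

fn main() {
    // Mask-decoder upscaling: two ConvTranspose2d(k=2, s=2) take 64 -> 128 -> 256.
    let up = conv_transpose2d_out(conv_transpose2d_out(64, 2, 2, 0), 2, 2, 0);
    // Prompt-encoder mask downscale: two Conv2d(k=2, s=2) take 256 -> 128 -> 64,
    // matching the 64x64 image-embedding grid (a final 1x1 conv keeps 64).
    let down = conv2d_out(conv2d_out(256, 2, 2, 0), 2, 2, 0);
    // Neck 3x3 conv with padding 1 preserves the 64x64 grid.
    let neck = conv2d_out(64, 3, 1, 1);
    println!("upscale: {up}, downscale: {down}, neck: {neck}");
}
```

Because the downscale and upscale stages are exact inverses in spatial size, a 256×256 low-resolution mask from one prediction can be fed back as a dense mask prompt.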

Enums

Device
Target device for graph execution. Re-exported so callers can construct it without depending on rlx-runtime themselves.