use serde::Deserialize;
/// Per-channel pixel mean (three channels), presumably subtracted during input normalization.
/// NOTE(review): the values equal the widely used ImageNet RGB statistics for pixels
/// scaled to [0, 1] — confirm channel order and scaling against the preprocessing code.
pub const SAM2_PIXEL_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
/// Per-channel pixel standard deviation (three channels), the counterpart of `SAM2_PIXEL_MEAN`.
pub const SAM2_PIXEL_STD: [f32; 3] = [0.229, 0.224, 0.225];
/// Input image size. Presumably the side length in pixels of a square input — TODO confirm.
pub const SAM2_IMG_SIZE: usize = 1024;
/// Patch-embedding kernel size (presumably of the first convolution — not consumed in this file).
pub const SAM2_PATCH_KERNEL: usize = 7;
/// Patch-embedding stride; it divides `SAM2_IMG_SIZE` to give `SAM2_PATCH_GRID`.
pub const SAM2_PATCH_STRIDE: usize = 4;
/// Patch-embedding padding (not consumed in this file).
pub const SAM2_PATCH_PADDING: usize = 3;
/// Patch-grid size after patch embedding: `SAM2_IMG_SIZE / SAM2_PATCH_STRIDE` = 1024 / 4 = 256.
pub const SAM2_PATCH_GRID: usize = SAM2_IMG_SIZE / SAM2_PATCH_STRIDE;
/// Number of query-pooling steps in the backbone (presumably one per stage transition — TODO confirm).
pub const SAM2_Q_POOL_COUNT: usize = 3;
/// Query-pooling stride (not consumed in this file).
pub const SAM2_Q_STRIDE: usize = 2;
/// Prompt-embedding width; reused in this file as the default FPN, decoder, memory-encoder
/// and memory-attention model dimension.
pub const SAM2_PROMPT_EMBED_DIM: usize = 256;
/// Hyper-parameters of the Hiera backbone.
///
/// The struct is deserializable through `serde`; the associated functions `tiny`, `small`,
/// `base_plus` and `large` return ready-made presets.
#[derive(Debug, Clone, Deserialize)]
pub struct Sam2HieraConfig {
/// Channel width of stage 0; `embed_dim_at_stage` doubles it for each later stage.
pub embed_dim: usize,
/// Attention-head count of stage 0; `num_heads_at_stage` doubles it for each later stage.
pub num_heads: usize,
/// Number of blocks in each stage, in stage order (`total_blocks` sums these).
pub stages: Vec<usize>,
/// Presumably the indices of blocks that use global rather than windowed attention —
/// not consumed in this file; verify against the backbone implementation.
pub global_att_blocks: Vec<usize>,
/// Presumably the spatial size of the background positional embedding —
/// not consumed in this file; verify against the backbone implementation.
pub window_pos_embed_bkg_spatial_size: [usize; 2],
/// Attention window size per stage, indexed by stage number (see `window_size_at_stage`).
pub window_spec: [usize; 4],
/// Layer-normalization epsilon (not consumed in this file).
pub layer_norm_eps: f64,
/// Presumably the ratio of MLP hidden width to block width — TODO confirm.
pub mlp_ratio: f64,
/// Presumably whether the QKV projection carries a bias term — TODO confirm.
pub qkv_bias: bool,
/// FPN output channel count; `Sam2FpnConfig::for_hiera` copies it into `d_model`.
pub fpn_out_chans: usize,
}
impl Sam2HieraConfig {
    /// Preset for the "tiny" backbone: width 96, one head, stages `[1, 2, 7, 2]`.
    pub fn tiny() -> Self {
        Self {
            embed_dim: 96,
            num_heads: 1,
            stages: vec![1, 2, 7, 2],
            global_att_blocks: vec![5, 7, 9],
            window_pos_embed_bkg_spatial_size: [7, 7],
            window_spec: [8, 4, 14, 7],
            layer_norm_eps: 1e-6,
            mlp_ratio: 4.0,
            qkv_bias: true,
            fpn_out_chans: SAM2_PROMPT_EMBED_DIM,
        }
    }

    /// Preset for the "small" backbone: identical to [`Self::tiny`] except for a deeper
    /// third stage and the matching global-attention block indices.
    pub fn small() -> Self {
        Self {
            stages: vec![1, 2, 11, 2],
            global_att_blocks: vec![7, 10, 13],
            ..Self::tiny()
        }
    }

    /// Preset for the "base+" backbone: width 112, two heads, stages `[2, 3, 16, 3]`.
    pub fn base_plus() -> Self {
        Self {
            embed_dim: 112,
            num_heads: 2,
            stages: vec![2, 3, 16, 3],
            global_att_blocks: vec![12, 16, 20],
            window_pos_embed_bkg_spatial_size: [14, 14],
            window_spec: [8, 4, 14, 7],
            layer_norm_eps: 1e-6,
            mlp_ratio: 4.0,
            qkv_bias: true,
            fpn_out_chans: SAM2_PROMPT_EMBED_DIM,
        }
    }

    /// Preset for the "large" backbone: width 144, two heads, stages `[2, 6, 36, 4]`.
    ///
    /// Fields not listed here (epsilon, MLP ratio, QKV bias, FPN width) are inherited
    /// from [`Self::base_plus`].
    pub fn large() -> Self {
        Self {
            embed_dim: 144,
            num_heads: 2,
            stages: vec![2, 6, 36, 4],
            global_att_blocks: vec![23, 33, 43],
            window_pos_embed_bkg_spatial_size: [7, 7],
            window_spec: [8, 4, 16, 8],
            ..Self::base_plus()
        }
    }

    /// Total number of blocks across all stages.
    pub fn total_blocks(&self) -> usize {
        self.stages.iter().sum()
    }

    /// Returns the index of the first block of every stage after stage 0, i.e. the running
    /// block count at the end of each stage except the last.
    ///
    /// For stages `[1, 2, 7, 2]` this yields `[1, 3, 10]`. The result is empty when
    /// `stages` holds zero or one entry.
    pub fn q_pool_block_indices(&self) -> Vec<usize> {
        let mut boundaries = Vec::with_capacity(SAM2_Q_POOL_COUNT);
        // `split_last` yields nothing for an empty `stages`, where `len() - 1` would underflow.
        if let Some((_, leading)) = self.stages.split_last() {
            let mut running = 0usize;
            for &blocks in leading {
                running += blocks;
                boundaries.push(running);
            }
        }
        boundaries
    }

    /// Returns the stage that contains block `block_idx`.
    ///
    /// Indices at or beyond [`Self::total_blocks`] map to the last stage; if `stages` is
    /// empty the result is `0`.
    pub fn stage_of_block(&self, block_idx: usize) -> usize {
        let mut stage_end = 0usize;
        self.stages
            .iter()
            .position(|&blocks| {
                stage_end += blocks;
                block_idx < stage_end
            })
            // Saturating so that an empty `stages` cannot underflow.
            .unwrap_or_else(|| self.stages.len().saturating_sub(1))
    }

    /// Channel width at stage `s`: `embed_dim * 2^s`.
    ///
    /// # Panics
    ///
    /// In builds with overflow checks, panics if the shift or the product overflows `usize`.
    pub fn embed_dim_at_stage(&self, s: usize) -> usize {
        self.embed_dim * (1 << s)
    }

    /// Attention-head count at stage `s`: `num_heads * 2^s`.
    ///
    /// # Panics
    ///
    /// In builds with overflow checks, panics if the shift or the product overflows `usize`.
    pub fn num_heads_at_stage(&self, s: usize) -> usize {
        self.num_heads * (1 << s)
    }

    /// Attention window size configured for stage `s`.
    ///
    /// # Panics
    ///
    /// Panics if `s >= 4`, because `window_spec` has exactly four entries.
    pub fn window_size_at_stage(&self, s: usize) -> usize {
        self.window_spec[s]
    }

    /// Grid size at stage `s`: `SAM2_PATCH_GRID / 2^s` (integer division).
    pub fn grid_size_at_stage(&self, s: usize) -> usize {
        SAM2_PATCH_GRID / (1 << s)
    }
}
/// Configuration of the FPN neck that sits on top of the Hiera backbone.
///
/// `Sam2FpnConfig::for_hiera` derives an instance from a `Sam2HieraConfig`.
#[derive(Debug, Clone)]
pub struct Sam2FpnConfig {
/// Output channel width of the neck; `for_hiera` sets it to the backbone's `fpn_out_chans`.
pub d_model: usize,
/// Channel width of each backbone stage, last stage first (widest to narrowest for the presets).
pub backbone_channel_list: Vec<usize>,
/// Presumably the pyramid levels that receive top-down features — `for_hiera` always
/// sets `[2, 3]`; verify the level numbering against the neck implementation.
pub fpn_top_down_levels: Vec<usize>,
/// Presumably selects nearest-neighbour interpolation for upsampling — TODO confirm.
pub interpolation_nearest: bool,
}
impl Sam2FpnConfig {
    /// Builds the FPN configuration that matches the backbone described by `cfg`.
    ///
    /// The channel list holds one entry per backbone stage, starting with the last stage.
    /// `d_model` is taken from `cfg.fpn_out_chans`; the top-down levels are fixed to
    /// `[2, 3]` and nearest interpolation is enabled.
    pub fn for_hiera(cfg: &Sam2HieraConfig) -> Self {
        let stage_count = cfg.stages.len();

        // Walk the stages from last to first so the widest stage comes first.
        let mut backbone_channel_list = Vec::with_capacity(stage_count);
        for stage in (0..stage_count).rev() {
            backbone_channel_list.push(cfg.embed_dim_at_stage(stage));
        }

        // An empty list compares 0 >= 0, so the check is a no-op for it.
        let leading = backbone_channel_list.first().copied().unwrap_or(0);
        let trailing = backbone_channel_list.last().copied().unwrap_or(0);
        debug_assert!(
            leading >= trailing,
            "backbone_channel_list must be coarse → fine"
        );

        Self {
            d_model: cfg.fpn_out_chans,
            backbone_channel_list,
            fpn_top_down_levels: vec![2, 3],
            interpolation_nearest: true,
        }
    }
}
/// Configuration of the mask decoder.
///
/// This file only stores these values; the field notes below are inferred from the field
/// names and should be confirmed against the decoder implementation. The `Default`
/// implementation supplies the values used by every `Sam2Config` preset.
#[derive(Debug, Clone)]
pub struct Sam2DecoderConfig {
/// Token width of the decoder transformer.
pub transformer_dim: usize,
/// Number of decoder transformer layers.
pub transformer_depth: usize,
/// Number of attention heads in the decoder transformer.
pub transformer_num_heads: usize,
/// Hidden width of the decoder transformer MLP.
pub transformer_mlp_dim: usize,
/// Number of mask tokens.
pub num_mask_tokens: usize,
/// Number of layers in the IoU prediction head.
pub iou_head_depth: usize,
/// Hidden width of the IoU prediction head.
pub iou_head_hidden_dim: usize,
/// Whether the IoU prediction is passed through a sigmoid.
pub iou_prediction_use_sigmoid: bool,
/// Whether object pointers are used.
pub use_object_pointer: bool,
/// Whether the object-pointer projection is an MLP.
pub use_mlp_for_obj_ptr_proj: bool,
/// Whether object scores are predicted.
pub pred_obj_scores: bool,
/// Whether the object-score predictor is an MLP.
pub pred_obj_scores_mlp: bool,
/// Whether the multimask token feeds the object pointer.
pub use_multimask_token_for_obj_ptr: bool,
/// Whether high-resolution features are used.
pub use_high_res_features: bool,
/// Whether multimask output is chosen dynamically from a stability score.
pub dynamic_multimask_via_stability: bool,
/// Delta used by the stability score.
pub dynamic_multimask_stability_delta: f32,
/// Threshold applied to the stability score.
pub dynamic_multimask_stability_thresh: f32,
/// Layer-normalization epsilon.
pub layer_norm_eps: f64,
}
impl Default for Sam2DecoderConfig {
    /// Returns the stock decoder settings; both the transformer width and the IoU-head
    /// hidden width are `SAM2_PROMPT_EMBED_DIM`.
    fn default() -> Self {
        // One width shared by the transformer tokens and the IoU head.
        let width = SAM2_PROMPT_EMBED_DIM;

        Self {
            // Transformer trunk.
            transformer_dim: width,
            transformer_depth: 2,
            transformer_num_heads: 8,
            transformer_mlp_dim: 2048,
            layer_norm_eps: 1e-6,

            // Mask tokens and IoU head.
            num_mask_tokens: 4,
            iou_head_depth: 3,
            iou_head_hidden_dim: width,
            iou_prediction_use_sigmoid: true,

            // Object pointer and object-score switches.
            use_object_pointer: true,
            use_mlp_for_obj_ptr_proj: true,
            use_multimask_token_for_obj_ptr: true,
            pred_obj_scores: true,
            pred_obj_scores_mlp: true,

            // Remaining feature switches.
            use_high_res_features: true,
            dynamic_multimask_via_stability: true,
            dynamic_multimask_stability_delta: 0.05,
            dynamic_multimask_stability_thresh: 0.98,
        }
    }
}
/// Configuration of the memory encoder.
///
/// This file only stores these values; the field notes below are inferred from the field
/// names and should be confirmed against the memory-encoder implementation.
#[derive(Debug, Clone)]
pub struct Sam2MemoryEncoderConfig {
/// Input feature width.
pub in_dim: usize,
/// Output (memory) feature width.
pub out_dim: usize,
/// Kernel size of the mask-downsampler convolutions.
pub mask_downsampler_kernel: usize,
/// Stride of each mask-downsampler convolution.
pub mask_downsampler_stride: usize,
/// Padding of each mask-downsampler convolution.
pub mask_downsampler_padding: usize,
/// Overall downsampling factor of the mask downsampler.
pub mask_downsampler_total_stride: usize,
/// Number of layers in the fuser.
pub fuser_num_layers: usize,
/// Channel width of the fuser.
pub fuser_dim: usize,
/// Kernel size of the fuser convolutions.
pub fuser_kernel: usize,
/// Padding of the fuser convolutions.
pub fuser_padding: usize,
/// Initial value of the fuser layer scale.
pub fuser_layer_scale_init_value: f32,
/// Whether the fuser uses depthwise convolutions.
pub fuser_use_dwconv: bool,
/// Whether the fuser applies an input projection.
pub fuser_input_projection: bool,
/// Number of positional-encoding features.
/// NOTE(review): unclear from this file whether this counts features per axis or in total.
pub pe_num_pos_feats: usize,
/// Temperature of the positional encoding.
pub pe_temperature: f32,
}
impl Default for Sam2MemoryEncoderConfig {
    /// Returns the stock memory-encoder settings; the input width and the fuser width are
    /// both `SAM2_PROMPT_EMBED_DIM`.
    fn default() -> Self {
        // The encoder input and the fuser share one width.
        let feature_width = SAM2_PROMPT_EMBED_DIM;

        Self {
            // Input / output widths.
            in_dim: feature_width,
            out_dim: 64,

            // Mask downsampler.
            mask_downsampler_total_stride: 16,
            mask_downsampler_kernel: 3,
            mask_downsampler_stride: 2,
            mask_downsampler_padding: 1,

            // Fuser.
            fuser_dim: feature_width,
            fuser_num_layers: 2,
            fuser_kernel: 7,
            fuser_padding: 3,
            fuser_layer_scale_init_value: 1e-6,
            fuser_use_dwconv: true,
            fuser_input_projection: false,

            // Positional encoding.
            pe_num_pos_feats: 32,
            pe_temperature: 10000.0,
        }
    }
}
/// Configuration of the memory-attention module.
///
/// This file only stores these values; the field notes below are inferred from the field
/// names and should be confirmed against the memory-attention implementation.
#[derive(Debug, Clone)]
pub struct Sam2MemoryConfig {
/// Model (token) width.
pub d_model: usize,
/// Number of memory-attention layers.
pub num_layers: usize,
/// Number of attention heads.
pub num_heads: usize,
/// Hidden width of the feed-forward sub-layer.
pub dim_feedforward: usize,
/// Layer-normalization epsilon.
pub layer_norm_eps: f64,
/// Input width of the keys/values fed to cross-attention.
pub kv_in_dim: usize,
/// Base (theta) of the rotary positional encoding.
pub rope_theta: f32,
/// Feature-map size assumed by the rotary positional encoding.
/// NOTE(review): axis order (height/width) is not visible in this file.
pub rope_feat_size: [usize; 2],
/// Whether rotary encodings are repeated for keys.
pub rope_k_repeat: bool,
/// Whether positional encoding is added at the module input.
pub pos_enc_at_input: bool,
/// Whether positional encoding is added in self-attention.
pub pos_enc_at_attn: bool,
/// Whether positional encoding is added to cross-attention keys.
pub pos_enc_at_cross_attn_keys: bool,
/// Whether positional encoding is added to cross-attention queries.
pub pos_enc_at_cross_attn_queries: bool,
/// Maximum number of object pointers passed to the encoder.
pub max_obj_ptrs_in_encoder: usize,
/// NOTE(review): meaning not evident from this file — presumably whether rotary encoding
/// is computed inside the exported memory-attention graph; confirm with the consumer.
pub mem_attn_in_graph_rope: bool,
/// Width of the stored memory features.
pub mem_dim: usize,
}
impl Default for Sam2MemoryConfig {
    /// Returns the stock memory-attention settings; the model width is
    /// `SAM2_PROMPT_EMBED_DIM`.
    fn default() -> Self {
        let model_width = SAM2_PROMPT_EMBED_DIM;

        // Fields are listed in the order in which the struct declares them.
        Self {
            d_model: model_width,
            num_layers: 4,
            num_heads: 1,
            dim_feedforward: 2048,
            layer_norm_eps: 1e-5,
            kv_in_dim: 64,
            rope_theta: 10000.0,
            rope_feat_size: [64, 64],
            rope_k_repeat: true,
            pos_enc_at_input: true,
            pos_enc_at_attn: false,
            pos_enc_at_cross_attn_keys: true,
            pos_enc_at_cross_attn_queries: false,
            max_obj_ptrs_in_encoder: 16,
            mem_attn_in_graph_rope: false,
            mem_dim: 64,
        }
    }
}
/// Top-level model configuration that bundles the per-component configurations.
///
/// The `hiera_*` associated functions build one instance per backbone preset.
#[derive(Debug, Clone)]
pub struct Sam2Config {
/// Backbone configuration.
pub hiera: Sam2HieraConfig,
/// FPN neck configuration; the presets derive it from `hiera` via `Sam2FpnConfig::for_hiera`.
pub fpn: Sam2FpnConfig,
/// Mask-decoder configuration.
pub decoder: Sam2DecoderConfig,
/// Memory-attention configuration.
pub memory: Sam2MemoryConfig,
/// Memory-encoder configuration.
pub memory_encoder: Sam2MemoryEncoderConfig,
}
impl Sam2Config {
pub fn hiera_tiny() -> Self {
let hiera = Sam2HieraConfig::tiny();
let fpn = Sam2FpnConfig::for_hiera(&hiera);
Self {
hiera,
fpn,
decoder: Sam2DecoderConfig::default(),
memory: Sam2MemoryConfig::default(),
memory_encoder: Sam2MemoryEncoderConfig::default(),
}
}
pub fn hiera_small() -> Self {
let hiera = Sam2HieraConfig::small();
let fpn = Sam2FpnConfig::for_hiera(&hiera);
Self {
hiera,
fpn,
decoder: Sam2DecoderConfig::default(),
memory: Sam2MemoryConfig::default(),
memory_encoder: Sam2MemoryEncoderConfig::default(),
}
}
pub fn hiera_base_plus() -> Self {
let hiera = Sam2HieraConfig::base_plus();
let fpn = Sam2FpnConfig::for_hiera(&hiera);
Self {
hiera,
fpn,
decoder: Sam2DecoderConfig::default(),
memory: Sam2MemoryConfig::default(),
memory_encoder: Sam2MemoryEncoderConfig::default(),
}
}
pub fn hiera_large() -> Self {
let hiera = Sam2HieraConfig::large();
let fpn = Sam2FpnConfig::for_hiera(&hiera);
Self {
hiera,
fpn,
decoder: Sam2DecoderConfig::default(),
memory: Sam2MemoryConfig::default(),
memory_encoder: Sam2MemoryEncoderConfig::default(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Query-pool boundaries of every preset must equal the expected block indices.
    #[test]
    fn q_pool_indices_match_reference() {
        let cases = [
            ("tiny", Sam2HieraConfig::tiny(), vec![1, 3, 10]),
            ("small", Sam2HieraConfig::small(), vec![1, 3, 14]),
            ("base_plus", Sam2HieraConfig::base_plus(), vec![2, 5, 21]),
            ("large", Sam2HieraConfig::large(), vec![2, 8, 44]),
        ];
        for (preset, cfg, expected) in cases.iter() {
            assert_eq!(
                &cfg.q_pool_block_indices(),
                expected,
                "preset {}",
                preset
            );
        }
    }

    /// Width doubles at every stage; the head count does as well.
    #[test]
    fn stage_dim_and_head_doubling() {
        let cfg = Sam2HieraConfig::base_plus();
        let expected_dims = [112, 224, 448, 896];
        for (stage, &dim) in expected_dims.iter().enumerate() {
            assert_eq!(cfg.embed_dim_at_stage(stage), dim, "stage {}", stage);
        }
        assert_eq!(cfg.num_heads_at_stage(3), 16);
    }

    /// The grid size halves at every stage, starting from the patch grid.
    #[test]
    fn grid_halves_per_stage() {
        let cfg = Sam2HieraConfig::base_plus();
        let expected_grids = [256, 128, 64, 32];
        for (stage, &grid) in expected_grids.iter().enumerate() {
            assert_eq!(cfg.grid_size_at_stage(stage), grid, "stage {}", stage);
        }
    }
}