use serde::Deserialize;
/// Three per-channel pixel mean values.
/// NOTE(review): presumably RGB order on a 0-255 scale, subtracted during input
/// normalization — confirm against the preprocessing code.
pub const SAM_PIXEL_MEAN: [f32; 3] = [123.675, 116.28, 103.53];
/// Three per-channel pixel standard-deviation values, index-aligned with
/// [`SAM_PIXEL_MEAN`].
/// NOTE(review): presumably the divisor applied after mean subtraction — confirm.
pub const SAM_PIXEL_STD: [f32; 3] = [58.395, 57.12, 57.375];
/// Image size constant (1024).
/// NOTE(review): presumably the side length, in pixels, of the square encoder
/// input — confirm against the image pre-processing.
pub const SAM_IMG_SIZE: usize = 1024;
/// Patch size constant (16); divides [`SAM_IMG_SIZE`] to give [`SAM_EMBED_HW`].
pub const SAM_PATCH_SIZE: usize = 16;
/// `SAM_IMG_SIZE / SAM_PATCH_SIZE` (64 with the values above). This is the value
/// returned by `SamEncoderConfig::num_patches_per_side`.
pub const SAM_EMBED_HW: usize = SAM_IMG_SIZE / SAM_PATCH_SIZE;
/// Shared width constant (256). In this file it is the encoder presets'
/// `out_chans` and the decoder defaults' `transformer_dim` and
/// `iou_head_hidden_dim`.
pub const SAM_PROMPT_EMBED_DIM: usize = 256;
/// Hyper-parameters describing the image encoder.
///
/// Deserializable via serde; ready-made presets are available through
/// `SamEncoderConfig::vit_b`, `vit_l` and `vit_h`.
#[derive(Debug, Clone, Deserialize)]
pub struct SamEncoderConfig {
/// Which encoder architecture variant this configuration describes.
pub encoder_kind: EncoderKind,
/// Embedding width. Divided by `num_heads` in `head_dim`.
pub embed_dim: usize,
/// Depth of the encoder (presets use 12, 24 and 32).
/// NOTE(review): presumably the number of transformer blocks — confirm.
pub depth: usize,
/// Number of attention heads. Used as the divisor in `head_dim`, so it must
/// be non-zero for that method to succeed.
pub num_heads: usize,
/// List of block indexes (the presets each list four).
/// NOTE(review): presumably the blocks that use global rather than windowed
/// attention — confirm against the encoder implementation.
pub global_attn_indexes: Vec<usize>,
/// Window size (presets use 14).
/// NOTE(review): presumably the side length of the local attention window — confirm.
pub window_size: usize,
/// Flag for relative positional information (presets: `true`).
/// NOTE(review): exact effect lives in the encoder implementation — confirm.
pub use_rel_pos: bool,
/// Flag for absolute positional information (presets: `true`).
/// NOTE(review): exact effect lives in the encoder implementation — confirm.
pub use_abs_pos: bool,
/// Flag for a bias term on the QKV projection (presets: `true`).
/// NOTE(review): inferred from the name — confirm.
pub qkv_bias: bool,
/// Epsilon value (presets use `1e-6`).
/// NOTE(review): presumably passed to the encoder's layer-norm layers — confirm.
pub layer_norm_eps: f64,
/// Output channel count; the presets set it to `SAM_PROMPT_EMBED_DIM`.
pub out_chans: usize,
}
/// Encoder architecture selector stored in `SamEncoderConfig::encoder_kind`.
///
/// No serde rename attributes are applied, so the variant names are used as
/// written when deserializing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
pub enum EncoderKind {
/// The variant used by the `vit_b` / `vit_l` / `vit_h` presets.
ViT,
/// Alternative encoder variant; no preset in this part of the file selects it.
TinyViT,
}
impl SamEncoderConfig {
    /// Assembles a `ViT` preset.
    ///
    /// Only the four arguments differ between the presets; every other field
    /// is fixed here so the shared values live in exactly one place.
    fn vit_preset(
        embed_dim: usize,
        depth: usize,
        num_heads: usize,
        global_attn_indexes: [usize; 4],
    ) -> Self {
        Self {
            encoder_kind: EncoderKind::ViT,
            embed_dim,
            depth,
            num_heads,
            global_attn_indexes: global_attn_indexes.to_vec(),
            window_size: 14,
            use_rel_pos: true,
            use_abs_pos: true,
            qkv_bias: true,
            layer_norm_eps: 1e-6,
            out_chans: SAM_PROMPT_EMBED_DIM,
        }
    }

    /// Preset with `embed_dim = 768`, `depth = 12`, `num_heads = 12` and
    /// `global_attn_indexes = [2, 5, 8, 11]`.
    pub fn vit_b() -> Self {
        Self::vit_preset(768, 12, 12, [2, 5, 8, 11])
    }

    /// Preset with `embed_dim = 1024`, `depth = 24`, `num_heads = 16` and
    /// `global_attn_indexes = [5, 11, 17, 23]`; all other fields match
    /// [`Self::vit_b`].
    pub fn vit_l() -> Self {
        Self::vit_preset(1024, 24, 16, [5, 11, 17, 23])
    }

    /// Preset with `embed_dim = 1280`, `depth = 32`, `num_heads = 16` and
    /// `global_attn_indexes = [7, 15, 23, 31]`; all other fields match
    /// [`Self::vit_b`].
    pub fn vit_h() -> Self {
        Self::vit_preset(1280, 32, 16, [7, 15, 23, 31])
    }

    /// Returns `embed_dim / num_heads` using integer division, so any
    /// remainder is discarded.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero.
    pub fn head_dim(&self) -> usize {
        let (width, heads) = (self.embed_dim, self.num_heads);
        width / heads
    }

    /// Returns [`SAM_EMBED_HW`]; the value does not depend on any field of
    /// this configuration.
    pub fn num_patches_per_side(&self) -> usize {
        SAM_EMBED_HW
    }
}
/// Hyper-parameters describing the mask decoder.
///
/// Default values are supplied by the `Default` implementation.
/// NOTE(review): unlike `SamEncoderConfig`, this type does not derive
/// `Deserialize` — confirm that is intentional.
#[derive(Debug, Clone)]
pub struct SamDecoderConfig {
/// Transformer width; defaults to `SAM_PROMPT_EMBED_DIM`.
pub transformer_dim: usize,
/// Transformer depth; defaults to 2.
/// NOTE(review): presumably the number of decoder transformer layers — confirm.
pub transformer_depth: usize,
/// Number of attention heads in the transformer; defaults to 8.
pub transformer_num_heads: usize,
/// MLP width inside the transformer; defaults to 2048.
/// NOTE(review): presumably the hidden size of each feed-forward block — confirm.
pub transformer_mlp_dim: usize,
/// Number of mask tokens; defaults to 4.
pub num_mask_tokens: usize,
/// Depth of the IoU head; defaults to 3.
/// NOTE(review): presumably the number of layers in the IoU prediction MLP — confirm.
pub iou_head_depth: usize,
/// Hidden width of the IoU head; defaults to `SAM_PROMPT_EMBED_DIM`.
pub iou_head_hidden_dim: usize,
/// Epsilon value; defaults to `1e-6`.
/// NOTE(review): presumably passed to the decoder's layer-norm layers — confirm.
pub layer_norm_eps: f64,
}
impl Default for SamDecoderConfig {
fn default() -> Self {
Self {
transformer_dim: SAM_PROMPT_EMBED_DIM,
transformer_depth: 2,
transformer_num_heads: 8,
transformer_mlp_dim: 2048,
num_mask_tokens: 4,
iou_head_depth: 3,
iou_head_hidden_dim: SAM_PROMPT_EMBED_DIM,
layer_norm_eps: 1e-6,
}
}
}
/// Complete model configuration: one encoder configuration paired with one
/// decoder configuration.
#[derive(Debug, Clone)]
pub struct SamConfig {
/// Image-encoder settings.
pub encoder: SamEncoderConfig,
/// Mask-decoder settings.
pub decoder: SamDecoderConfig,
}
impl SamConfig {
pub fn vit_b() -> Self {
Self {
encoder: SamEncoderConfig::vit_b(),
decoder: SamDecoderConfig::default(),
}
}
pub fn vit_l() -> Self {
Self {
encoder: SamEncoderConfig::vit_l(),
decoder: SamDecoderConfig::default(),
}
}
pub fn vit_h() -> Self {
Self {
encoder: SamEncoderConfig::vit_h(),
decoder: SamDecoderConfig::default(),
}
}
}