use super::config::{SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PROMPT_EMBED_DIM, SamConfig};
use super::image_encoder::build_sam_encoder_graph;
use super::mask_decoder::{MaskDecoderWeights, extract_mask_decoder_weights, mask_decoder_forward};
use super::preprocess::{SamPreprocessWeights, assemble_patch_tokens, preprocess_image};
use super::prompt_encoder::{
PromptEncoderOutput, PromptEncoderWeights, extract_prompt_encoder_weights,
prompt_encoder_forward,
};
use super::prompt_mask_ir::SamPromptMaskCompiled;
use super::upscale_ir::SamMaskUpscaleCompiled;
use anyhow::Result;
use rlx_runtime::{CompiledGraph, Device, Session};
use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
use std::path::Path;
/// Mask channel count passed to `extract_prompt_encoder_weights` when building [`Sam`].
pub const SAM_MASK_IN_CHANS: usize = 16;
/// Segment Anything (SAM) model assembled from compiled graphs and extracted weights.
///
/// Built with [`Sam::from_safetensors`] / [`Sam::from_safetensors_on`]; images are encoded
/// with [`Sam::encode_image`] and masks are predicted with [`Sam::predict_masks`]
/// (or both at once via [`Sam::forward`]).
pub struct Sam {
/// Configuration the model was built with.
cfg: SamConfig,
/// Compiled image-encoder graph (fed through its `"hidden"` input), parameters uploaded.
encoder: CompiledGraph,
/// Preprocessing weights used by `assemble_patch_tokens` before running `encoder`.
pre: SamPreprocessWeights,
/// Prompt-encoder weights consumed by `prompt_encoder_forward`.
prompt_enc: PromptEncoderWeights,
/// Compiled mask-prompt stack built from `prompt_enc`.
mask_stack: SamPromptMaskCompiled,
/// Mask-decoder weights consumed by `mask_decoder_forward`.
mask_dec: MaskDecoderWeights,
/// Compiled mask upscaling graph built from `mask_dec`.
upscale: SamMaskUpscaleCompiled,
/// Compiled hypernetwork matmul used by the mask decoder.
hyper_matmul: MaskHyperMatmulCompiled,
/// Compiled hypernetwork MLPs, built from `mask_dec.hyper_mlps`.
hyper_mlps_ir: Vec<MlpReluCompiled>,
/// Compiled IoU-prediction head, built from `mask_dec.iou_head`.
iou_head_ir: MlpReluCompiled,
/// Compiled two-way transformer, built from `mask_dec.transformer`.
tw_ir: rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled,
}
impl Sam {
    /// Loads a SAM model from `weights_path` and compiles it for the CPU.
    ///
    /// Equivalent to [`Sam::from_safetensors_on`] with [`Device::Cpu`].
    ///
    /// # Errors
    /// Propagates any error from [`Sam::from_safetensors_on`].
    pub fn from_safetensors(weights_path: &str, cfg: SamConfig) -> Result<Self> {
        Self::from_safetensors_on(weights_path, cfg, Device::Cpu)
    }

    /// Loads a SAM model from `weights_path` and compiles every stage for `device`.
    ///
    /// Steps, in order:
    /// 1. check that `device` is accepted by `rlx_core::validate_sam_device`;
    /// 2. load the weight map (the loader is given `rlx_core::SAM_GGUF_ARCHES`);
    /// 3. build and compile the image-encoder graph, then upload its parameters;
    /// 4. extract the prompt-encoder and mask-decoder weights and compile their sub-graphs
    ///    (mask-prompt stack, mask upscaling, hypernetwork matmul, hypernetwork MLPs,
    ///    IoU head, two-way transformer).
    ///
    /// All compilations use the profile returned by
    /// `crate::profile::sam_profile_near_weights` for the same path.
    ///
    /// # Errors
    /// Returns an error if the device is rejected, the weights cannot be loaded, or any
    /// weight-extraction / sub-graph compilation step fails.
    pub fn from_safetensors_on(weights_path: &str, cfg: SamConfig, device: Device) -> Result<Self> {
        rlx_core::validate_sam_device("sam", device)?;
        let weights = Path::new(weights_path);
        let mut wm = rlx_core::load_weight_map(weights, rlx_core::SAM_GGUF_ARCHES)?;

        // Image encoder: the builder also returns the preprocessing weights and the
        // parameter tensors that must be uploaded into the compiled graph.
        let (graph, params, pre) = build_sam_encoder_graph(&cfg.encoder, &mut wm)?;
        let profile = crate::profile::sam_profile_near_weights(weights);
        let opts = rlx_core::flow_bridge::compile_options_for_profile(&profile, device);
        let mut encoder = Session::new(device).compile_with(graph, &opts);
        for (name, data) in &params {
            encoder.set_param(name, data);
        }

        // Prompt encoder.
        let prompt_enc =
            extract_prompt_encoder_weights(&mut wm, cfg.encoder.out_chans, SAM_MASK_IN_CHANS)?;
        let mask_stack =
            SamPromptMaskCompiled::compile_with_profile(&prompt_enc, device, &profile)?;

        // Mask decoder and its compiled sub-graphs.
        let mask_dec = extract_mask_decoder_weights(
            &mut wm,
            cfg.decoder.transformer_dim,
            cfg.decoder.num_mask_tokens,
            cfg.decoder.iou_head_depth,
            cfg.decoder.iou_head_hidden_dim,
            cfg.decoder.transformer_depth,
            cfg.decoder.transformer_num_heads,
            cfg.decoder.transformer_mlp_dim,
        )?;
        let upscale = SamMaskUpscaleCompiled::compile_with_profile(&mask_dec, device, &profile)?;
        // NOTE(review): `transformer_dim / 8` presumably matches the channel count of the
        // upscaled image embedding produced by `upscale` — confirm against that graph.
        let hyper_matmul = MaskHyperMatmulCompiled::compile_with_profile(
            mask_dec.num_mask_tokens,
            cfg.decoder.transformer_dim / 8,
            SAM_EMBED_HW,
            device,
            &profile,
        )?;
        let hyper_mlps_ir =
            super::mlp_ir::compile_hyper_mlps_with_profile(&mask_dec.hyper_mlps, device, &profile)?;
        let iou_head_ir =
            super::mlp_ir::compile_iou_head_with_profile(&mask_dec.iou_head, device, &profile)?;
        // NOTE(review): presumably one IoU token plus the mask tokens, i.e. the fixed
        // decoder queries before sparse prompt tokens are appended — confirm in transformer_ir.
        let base_q_n = 1 + mask_dec.num_mask_tokens;
        let tw_ir = super::transformer_ir::compile_two_way_transformer_with_profile(
            &mask_dec.transformer,
            base_q_n,
            SAM_EMBED_HW,
            device,
            &profile,
        )?;

        Ok(Self {
            cfg,
            encoder,
            pre,
            prompt_enc,
            mask_stack,
            mask_dec,
            upscale,
            hyper_matmul,
            hyper_mlps_ir,
            iou_head_ir,
            tw_ir,
        })
    }

    /// Runs the image encoder on an already-preprocessed image and returns its first output.
    ///
    /// `image_nchw` is turned into patch tokens by `assemble_patch_tokens` and fed to the
    /// compiled encoder through its `"hidden"` input.
    ///
    /// # Panics
    /// Panics if `assemble_patch_tokens` fails (e.g. on an input it rejects) or if the
    /// encoder produces no outputs.
    pub fn encode_image(&mut self, image_nchw: &[f32]) -> Vec<f32> {
        let hidden = assemble_patch_tokens(&self.pre, image_nchw).expect("assemble_patch_tokens");
        let outputs = self.encoder.run(&[("hidden", hidden.as_slice())]);
        outputs.into_iter().next().expect("encoder output")
    }

    /// Predicts masks for precomputed `image_embeddings` and the given prompts.
    ///
    /// `points`, `boxes` and `masks` are forwarded unchanged to `prompt_encoder_forward`
    /// (presumably `points` is `(coords, labels)` — confirm against that function).
    /// `multimask_output` is forwarded to `mask_decoder_forward`.
    ///
    /// # Errors
    /// Propagates errors from the prompt encoder or the mask decoder.
    pub fn predict_masks(
        &mut self,
        image_embeddings: &[f32],
        points: Option<(&[f32], &[f32])>,
        boxes: Option<&[f32]>,
        masks: Option<&[f32]>,
        multimask_output: bool,
    ) -> Result<MaskPrediction> {
        let pe: PromptEncoderOutput =
            prompt_encoder_forward(&self.prompt_enc, &mut self.mask_stack, points, boxes, masks)?;
        let (mask_logits, iou_pred, num_masks, mask_side) = mask_decoder_forward(
            &self.mask_dec,
            &mut self.upscale,
            Some(&mut self.hyper_matmul),
            Some(&mut self.hyper_mlps_ir),
            Some(&mut self.iou_head_ir),
            Some(&mut self.tw_ir),
            image_embeddings,
            &pe.image_pe,
            &pe.sparse_embeddings,
            pe.num_sparse_tokens,
            &pe.dense_embeddings,
            multimask_output,
        )?;
        Ok(MaskPrediction {
            mask_logits,
            iou_pred,
            num_masks,
            mask_side,
        })
    }

    /// End-to-end inference: preprocesses an image, encodes it and predicts masks.
    ///
    /// `rgb` is the raw image of `h_in` × `w_in` pixels in the layout `preprocess_image`
    /// expects (assumed interleaved RGB bytes — confirm). Returns the prediction together
    /// with the `(height, width)` reported by `preprocess_image` for the resized image.
    ///
    /// # Errors
    /// Propagates errors from [`Sam::predict_masks`].
    ///
    /// # Panics
    /// Panics under the same conditions as [`Sam::encode_image`].
    pub fn forward(
        &mut self,
        rgb: &[u8],
        h_in: usize,
        w_in: usize,
        points: Option<(&[f32], &[f32])>,
        boxes: Option<&[f32]>,
        masks: Option<&[f32]>,
        multimask_output: bool,
    ) -> Result<(MaskPrediction, (usize, usize))> {
        let (image_nchw, (resized_h, resized_w)) = preprocess_image(rgb, h_in, w_in);
        let image_embeddings = self.encode_image(&image_nchw);
        let pred = self.predict_masks(&image_embeddings, points, boxes, masks, multimask_output)?;
        Ok((pred, (resized_h, resized_w)))
    }

    /// Returns the configuration this model was built with.
    pub fn config(&self) -> &SamConfig {
        &self.cfg
    }

    /// Expected side length of a predicted mask, computed as `4 * SAM_EMBED_HW`.
    pub fn mask_side(&self) -> usize {
        4 * SAM_EMBED_HW
    }

    /// Model input image size (`SAM_IMG_SIZE`).
    pub fn input_image_size(&self) -> usize {
        SAM_IMG_SIZE
    }
}
/// Output of [`Sam::predict_masks`]: mask logits plus a predicted IoU score per mask.
#[derive(Debug, Clone, PartialEq)]
pub struct MaskPrediction {
    /// Mask logits for all returned masks, flattened into one buffer
    /// (assumed `num_masks` masks of `mask_side × mask_side` each — TODO confirm layout).
    pub mask_logits: Vec<f32>,
    /// Predicted IoU score for each mask, as returned by the decoder.
    pub iou_pred: Vec<f32>,
    /// Number of masks returned by the decoder.
    pub num_masks: usize,
    /// Side length of each mask, as reported by the decoder.
    pub mask_side: usize,
}

impl MaskPrediction {
    /// Returns `(index, score)` of the mask with the highest predicted IoU, or `None` when
    /// `iou_pred` is empty.
    ///
    /// Scores are compared with [`f32::total_cmp`], so the result is well-defined even with
    /// NaNs (a positive NaN ranks above every number). When several scores tie for the
    /// maximum, the last one wins (the tie-breaking rule of [`Iterator::max_by`]).
    pub fn best_by_iou(&self) -> Option<(usize, f32)> {
        self.iou_pred
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, v)| (i, *v))
    }
}
/// Returns the ViT-B SAM configuration; shorthand for [`SamConfig::vit_b`].
pub fn sam_vit_b_config() -> SamConfig {
SamConfig::vit_b()
}
// `SAM_PROMPT_EMBED_DIM` is imported at the top of this file but is not otherwise used here;
// referencing it keeps the import from raising an unused-import warning. Nothing calls this
// function, hence `allow(dead_code)`.
// NOTE(review): the leading underscore in the name may already suppress `dead_code`; consider
// dropping the import instead once it is confirmed unused across the module.
#[allow(dead_code)]
fn _silence_unused() {
let _ = SAM_PROMPT_EMBED_DIM;
}