1use super::config::{SAM_EMBED_HW, SAM_IMG_SIZE, SAM_PROMPT_EMBED_DIM, SamConfig};
24use super::image_encoder::build_sam_encoder_graph;
25use super::mask_decoder::{MaskDecoderWeights, extract_mask_decoder_weights, mask_decoder_forward};
26use super::preprocess::{SamPreprocessWeights, assemble_patch_tokens, preprocess_image};
27use super::prompt_encoder::{
28 PromptEncoderOutput, PromptEncoderWeights, extract_prompt_encoder_weights,
29 prompt_encoder_forward,
30};
31use super::prompt_mask_ir::SamPromptMaskCompiled;
32use super::upscale_ir::SamMaskUpscaleCompiled;
33use anyhow::Result;
34use rlx_runtime::{CompiledGraph, Device, Session};
35use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
36use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
37use std::path::Path;
38
/// Mask-input channel count handed to `extract_prompt_encoder_weights` when
/// the prompt-encoder weights are pulled out of the weight map.
pub const SAM_MASK_IN_CHANS: usize = 16;
42
43pub struct Sam {
47 cfg: SamConfig,
48 encoder: CompiledGraph,
49 pre: SamPreprocessWeights,
50 prompt_enc: PromptEncoderWeights,
51 mask_stack: SamPromptMaskCompiled,
52 mask_dec: MaskDecoderWeights,
53 upscale: SamMaskUpscaleCompiled,
54 hyper_matmul: MaskHyperMatmulCompiled,
55 hyper_mlps_ir: Vec<MlpReluCompiled>,
56 iou_head_ir: MlpReluCompiled,
57 tw_ir: rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled,
58}
59
60impl Sam {
61 pub fn from_safetensors(weights_path: &str, cfg: SamConfig) -> Result<Self> {
66 Self::from_safetensors_on(weights_path, cfg, Device::Cpu)
67 }
68
69 pub fn from_safetensors_on(weights_path: &str, cfg: SamConfig, device: Device) -> Result<Self> {
83 rlx_core::validate_sam_device("sam", device)?;
84 let mut wm = rlx_core::load_weight_map(Path::new(weights_path), rlx_core::SAM_GGUF_ARCHES)?;
85 let (graph, params, pre) = build_sam_encoder_graph(&cfg.encoder, &mut wm)?;
86 let profile = crate::profile::sam_profile_near_weights(std::path::Path::new(weights_path));
87 let opts = rlx_core::flow_bridge::compile_options_for_profile(&profile, device);
88 let mut encoder = Session::new(device).compile_with(graph, &opts);
89 for (name, data) in ¶ms {
90 encoder.set_param(name, data);
91 }
92 let prompt_enc =
93 extract_prompt_encoder_weights(&mut wm, cfg.encoder.out_chans, SAM_MASK_IN_CHANS)?;
94 let mask_stack =
95 SamPromptMaskCompiled::compile_with_profile(&prompt_enc, device, &profile)?;
96 let mask_dec = extract_mask_decoder_weights(
97 &mut wm,
98 cfg.decoder.transformer_dim,
99 cfg.decoder.num_mask_tokens,
100 cfg.decoder.iou_head_depth,
101 cfg.decoder.iou_head_hidden_dim,
102 cfg.decoder.transformer_depth,
103 cfg.decoder.transformer_num_heads,
104 cfg.decoder.transformer_mlp_dim,
105 )?;
106 let upscale = SamMaskUpscaleCompiled::compile_with_profile(&mask_dec, device, &profile)?;
107 let hyper_matmul = MaskHyperMatmulCompiled::compile_with_profile(
108 mask_dec.num_mask_tokens,
109 cfg.decoder.transformer_dim / 8,
110 SAM_EMBED_HW,
111 device,
112 &profile,
113 )?;
114 let hyper_mlps_ir =
115 super::mlp_ir::compile_hyper_mlps_with_profile(&mask_dec.hyper_mlps, device, &profile)?;
116 let iou_head_ir =
117 super::mlp_ir::compile_iou_head_with_profile(&mask_dec.iou_head, device, &profile)?;
118 let base_q_n = 1 + mask_dec.num_mask_tokens;
119 let tw_ir = super::transformer_ir::compile_two_way_transformer_with_profile(
120 &mask_dec.transformer,
121 base_q_n,
122 SAM_EMBED_HW,
123 device,
124 &profile,
125 )?;
126 Ok(Self {
127 cfg,
128 encoder,
129 pre,
130 prompt_enc,
131 mask_stack,
132 mask_dec,
133 upscale,
134 hyper_matmul,
135 hyper_mlps_ir,
136 iou_head_ir,
137 tw_ir,
138 })
139 }
140
141 pub fn encode_image(&mut self, image_nchw: &[f32]) -> Vec<f32> {
145 let hidden = assemble_patch_tokens(&self.pre, image_nchw).expect("assemble_patch_tokens");
146 let outputs = self.encoder.run(&[("hidden", hidden.as_slice())]);
147 outputs.into_iter().next().expect("encoder output")
148 }
149
150 pub fn predict_masks(
152 &mut self,
153 image_embeddings: &[f32],
154 points: Option<(&[f32], &[f32])>,
155 boxes: Option<&[f32]>,
156 masks: Option<&[f32]>,
157 multimask_output: bool,
158 ) -> Result<MaskPrediction> {
159 let pe: PromptEncoderOutput =
160 prompt_encoder_forward(&self.prompt_enc, &mut self.mask_stack, points, boxes, masks)?;
161 let (mask_logits, iou_pred, num_masks, mask_side) = mask_decoder_forward(
162 &self.mask_dec,
163 &mut self.upscale,
164 Some(&mut self.hyper_matmul),
165 Some(&mut self.hyper_mlps_ir),
166 Some(&mut self.iou_head_ir),
167 Some(&mut self.tw_ir),
168 image_embeddings,
169 &pe.image_pe,
170 &pe.sparse_embeddings,
171 pe.num_sparse_tokens,
172 &pe.dense_embeddings,
173 multimask_output,
174 )?;
175 Ok(MaskPrediction {
176 mask_logits,
177 iou_pred,
178 num_masks,
179 mask_side,
180 })
181 }
182
183 pub fn forward(
185 &mut self,
186 rgb: &[u8],
187 h_in: usize,
188 w_in: usize,
189 points: Option<(&[f32], &[f32])>,
190 boxes: Option<&[f32]>,
191 masks: Option<&[f32]>,
192 multimask_output: bool,
193 ) -> Result<(MaskPrediction, (usize, usize))> {
194 let (image_nchw, (resized_h, resized_w)) = preprocess_image(rgb, h_in, w_in);
195 let image_embeddings = self.encode_image(&image_nchw);
196 let pred = self.predict_masks(&image_embeddings, points, boxes, masks, multimask_output)?;
197 Ok((pred, (resized_h, resized_w)))
198 }
199
200 pub fn config(&self) -> &SamConfig {
201 &self.cfg
202 }
203
204 pub fn mask_side(&self) -> usize {
207 4 * SAM_EMBED_HW
208 }
209
210 pub fn input_image_size(&self) -> usize {
212 SAM_IMG_SIZE
213 }
214}
215
/// Output of the mask decoder for one set of prompts.
pub struct MaskPrediction {
    /// Mask logits exactly as returned by the mask decoder.
    /// NOTE(review): layout is presumably `num_masks * mask_side * mask_side` — confirm.
    pub mask_logits: Vec<f32>,
    /// Predicted IoU scores as returned by the mask decoder.
    pub iou_pred: Vec<f32>,
    /// Number of masks reported by the mask decoder.
    pub num_masks: usize,
    /// Mask side length reported by the mask decoder.
    pub mask_side: usize,
}

impl MaskPrediction {
    /// Index and value of the largest entry in `iou_pred`, or `None` when
    /// `iou_pred` is empty.
    ///
    /// Entries are ordered with `f32::total_cmp`; among equal maxima the last
    /// one wins.
    pub fn best_by_iou(&self) -> Option<(usize, f32)> {
        self.iou_pred
            .iter()
            .copied()
            .enumerate()
            .reduce(|leader, candidate| {
                // Keep the current leader only when the candidate is strictly
                // smaller, so ties resolve to the later index.
                if candidate.1.total_cmp(&leader.1).is_lt() {
                    leader
                } else {
                    candidate
                }
            })
    }
}
240
241pub fn sam_vit_b_config() -> SamConfig {
244 SamConfig::vit_b()
245}
246
247#[allow(dead_code)]
248fn _silence_unused() {
249 let _ = SAM_PROMPT_EMBED_DIM;
250}