1use super::config::{SAM2_IMG_SIZE, Sam2Config, Sam2DecoderConfig};
33use super::fpn_neck::{FpnLevel, FpnNeckWeights, apply_fpn_neck};
34use super::fpn_neck_ir::{Sam2FpnNeckIr, compile_fpn_neck_ir};
35use super::image_encoder::build_sam2_image_encoder_graph;
36use super::mask_decoder::{
37 Sam2MaskDecoderOutput, Sam2MaskDecoderWeights, extract_mask_decoder_weights,
38 mask_decoder_forward,
39};
40use super::memory_attention::{
41 Sam2MemoryAttentionWeights, extract_memory_attention_weights, memory_attention_forward,
42};
43use super::memory_attention_ir::{MemoryAttentionCompiled, max_memory_slots};
44use super::memory_encoder::{
45 Sam2MemoryEncoderOutput, Sam2MemoryEncoderWeights, extract_memory_encoder_weights,
46 memory_encoder_forward,
47};
48use super::preprocess::{Sam2PreprocessWeights, assemble_patch_tokens, preprocess_image};
49use super::prompt_encoder::{
50 SAM2_MASK_IN_CHANS, SAM2_PROMPT_GRID, Sam2PromptEncoderOutput, Sam2PromptEncoderWeights,
51 extract_prompt_encoder_weights, prompt_encoder_forward,
52};
53use super::prompt_mask_ir::Sam2PromptMaskCompiled;
54use super::upscale_ir::Sam2MaskUpscaleCompiled;
55use anyhow::{Result, ensure};
56use rlx_flow::CompileProfile;
57use rlx_runtime::{CompiledGraph, Device, Session};
58use rlx_sam::profile::sam2_profile_near_weights;
59use rlx_sam_ir::mask_hyper_matmul_ir::MaskHyperMatmulCompiled;
60use rlx_sam_ir::mlp_relu_ir::MlpReluCompiled;
61use std::path::Path;
62
/// Per-stage output geometry of the Hiera image encoder, listed in stage order.
#[derive(Clone, Debug, PartialEq, Eq)]
struct HieraOutputShapes {
    /// `(height, width)` of each stage's feature map.
    stage_hw: Vec<(usize, usize)>,
    /// Embedding (channel) dimension of each stage, as reported by `embed_dim_at_stage`.
    stage_dims: Vec<usize>,
}
69
70pub struct Sam2 {
74 cfg: Sam2Config,
75 encoder: CompiledGraph,
76 pre: Sam2PreprocessWeights,
77 fpn: FpnNeckWeights,
78 fpn_ir: Sam2FpnNeckIr,
79 prompt_enc: Sam2PromptEncoderWeights,
80 mask_dec: Sam2MaskDecoderWeights,
81 mask_stack: Sam2PromptMaskCompiled,
82 upscale: Sam2MaskUpscaleCompiled,
83 hyper_matmul: MaskHyperMatmulCompiled,
84 hyper_mlps_ir: Vec<MlpReluCompiled>,
85 iou_head_ir: MlpReluCompiled,
86 obj_score_head_ir: Option<MlpReluCompiled>,
87 obj_ptr_proj_ir: Option<MlpReluCompiled>,
88 tw_ir: rlx_sam_ir::twoway_transformer_ir::TwoWayTransformerCompiled,
89 mem_enc: Sam2MemoryEncoderWeights,
90 mem_attn: Sam2MemoryAttentionWeights,
91 mem_attn_ir: Option<MemoryAttentionCompiled>,
93 mem_attn_device: Device,
94 hiera_shapes: HieraOutputShapes,
95 compile_profile: CompileProfile,
96}
97
98impl Sam2 {
99 pub fn from_safetensors(weights_path: &str, cfg: Sam2Config) -> Result<Self> {
103 Self::from_safetensors_on(weights_path, cfg, Device::Cpu)
104 }
105
106 pub fn from_safetensors_on(
110 weights_path: &str,
111 cfg: Sam2Config,
112 device: Device,
113 ) -> Result<Self> {
114 rlx_core::validate_sam_device("sam2", device)?;
115 let mut wm =
116 rlx_core::load_weight_map(Path::new(weights_path), rlx_core::SAM2_GGUF_ARCHES)?;
117 let compile_profile = sam2_profile_near_weights(Path::new(weights_path));
118
119 let (graph, params, pre, fpn) = build_sam2_image_encoder_graph(&cfg.hiera, &mut wm)?;
122
123 let hiera_shapes = HieraOutputShapes {
124 stage_hw: (0..cfg.hiera.stages.len())
125 .map(|s| {
126 (
127 cfg.hiera.grid_size_at_stage(s),
128 cfg.hiera.grid_size_at_stage(s),
129 )
130 })
131 .collect(),
132 stage_dims: (0..cfg.hiera.stages.len())
133 .map(|s| cfg.hiera.embed_dim_at_stage(s))
134 .collect(),
135 };
136
137 let prompt_enc = extract_prompt_encoder_weights(
139 &mut wm,
140 cfg.decoder.transformer_dim,
141 SAM2_MASK_IN_CHANS,
142 )?;
143
144 let mask_dec = extract_mask_decoder_weights(&mut wm, &cfg.decoder)?;
146
147 let mut mem_enc = extract_memory_encoder_weights(&mut wm, &cfg.memory_encoder)?;
149 super::memory_encoder::compile_memory_encoder_ir(
150 &mut mem_enc,
151 SAM2_IMG_SIZE,
152 SAM2_IMG_SIZE,
153 SAM2_PROMPT_GRID,
154 SAM2_PROMPT_GRID,
155 device,
156 &compile_profile,
157 )?;
158
159 let mem_attn = extract_memory_attention_weights(&mut wm, &cfg.memory)?;
161 let grid = cfg.hiera.grid_size_at_stage(cfg.hiera.stages.len() - 1);
162 let mask_stack =
163 Sam2PromptMaskCompiled::compile_with_profile(&prompt_enc, device, &compile_profile)?;
164 let upscale = Sam2MaskUpscaleCompiled::compile_with_profile(
165 &mask_dec,
166 grid,
167 device,
168 &compile_profile,
169 )?;
170 let hyper_matmul = MaskHyperMatmulCompiled::compile_with_profile(
171 mask_dec.num_mask_tokens,
172 cfg.decoder.transformer_dim / 8,
173 grid,
174 device,
175 &compile_profile,
176 )?;
177 let hyper_mlps_ir = super::mlp_ir::compile_hyper_mlps_with_profile(
178 &mask_dec.hyper_mlps,
179 device,
180 &compile_profile,
181 )?;
182 let iou_head_ir = super::mlp_ir::compile_hyper_mlp_with_profile(
183 &mask_dec.iou_head,
184 device,
185 &compile_profile,
186 )?;
187 let obj_score_head_ir = super::mlp_ir::compile_optional_hyper_mlp_with_profile(
188 &mask_dec.obj_score_head,
189 1,
190 device,
191 &compile_profile,
192 )?;
193 let obj_ptr_rows = super::mlp_ir::obj_ptr_proj_rows(
194 mask_dec.num_mask_tokens,
195 mask_dec.use_multimask_token_for_obj_ptr,
196 );
197 let obj_ptr_proj_ir = super::mlp_ir::compile_optional_hyper_mlp_with_profile(
198 &mask_dec.obj_ptr_proj,
199 obj_ptr_rows,
200 device,
201 &compile_profile,
202 )?;
203 let s_tok = if mask_dec.obj_score_token.is_some() {
204 1
205 } else {
206 0
207 };
208 let base_q_n = s_tok + 1 + mask_dec.num_mask_tokens;
209 let grid = cfg.hiera.grid_size_at_stage(cfg.hiera.stages.len() - 1);
210 let tw_ir = super::transformer_ir::compile_two_way_transformer_with_profile(
211 &mask_dec.transformer,
212 base_q_n,
213 grid,
214 device,
215 &compile_profile,
216 )?;
217 let fpn_ir = compile_fpn_neck_ir(
218 &fpn,
219 &hiera_shapes.stage_hw,
220 &hiera_shapes.stage_dims,
221 device,
222 &compile_profile,
223 )?;
224
225 let opts = rlx_core::flow_bridge::compile_options_for_profile(&compile_profile, device);
226 let mut encoder = Session::new(device).compile_with(graph, &opts);
227 for (name, data) in ¶ms {
228 encoder.set_param(name, data);
229 }
230
231 Ok(Self {
236 cfg,
237 encoder,
238 pre,
239 fpn,
240 fpn_ir,
241 prompt_enc,
242 mask_dec,
243 mask_stack,
244 upscale,
245 hyper_matmul,
246 hyper_mlps_ir,
247 iou_head_ir,
248 obj_score_head_ir,
249 obj_ptr_proj_ir,
250 tw_ir,
251 mem_enc,
252 mem_attn,
253 mem_attn_ir: None,
254 mem_attn_device: device,
255 hiera_shapes,
256 compile_profile,
257 })
258 }
259
260 pub fn compile_profile(&self) -> &CompileProfile {
262 &self.compile_profile
263 }
264
265 pub fn config(&self) -> &Sam2Config {
266 &self.cfg
267 }
268
269 fn ensure_mem_attn_ir(&mut self) -> Result<()> {
270 if self.mem_attn_ir.is_some() {
271 return Ok(());
272 }
273 let [rope_x, rope_y] = self.cfg.memory.rope_feat_size;
274 let n_img_mem = rope_x * rope_y;
275 let max_n_mem = max_memory_slots(n_img_mem, self.cfg.memory.max_obj_ptrs_in_encoder);
276 self.mem_attn_ir = Some(if self.cfg.memory.mem_attn_in_graph_rope {
277 MemoryAttentionCompiled::compile_in_graph_rope_with_profile(
278 &self.mem_attn,
279 n_img_mem,
280 max_n_mem,
281 self.cfg.memory.max_obj_ptrs_in_encoder,
282 self.mem_attn_device,
283 &self.compile_profile,
284 )?
285 } else {
286 MemoryAttentionCompiled::compile_with_profile(
287 &self.mem_attn,
288 n_img_mem,
289 max_n_mem,
290 self.cfg.memory.max_obj_ptrs_in_encoder,
291 self.mem_attn_device,
292 &self.compile_profile,
293 )?
294 });
295 Ok(())
296 }
297
298 fn encode(&mut self, image_u8: &[u8], h_in: usize, w_in: usize) -> Result<Vec<FpnLevel>> {
301 let image_nchw = preprocess_image(image_u8, h_in, w_in);
302 let hidden = assemble_patch_tokens(&self.pre, &image_nchw)?;
303 let outputs = self.encoder.run(&[("hidden", hidden.as_slice())]);
304 ensure!(
305 outputs.len() == self.hiera_shapes.stage_dims.len(),
306 "encoder produced {} outputs (expected {})",
307 outputs.len(),
308 self.hiera_shapes.stage_dims.len()
309 );
310 apply_fpn_neck(
311 &self.fpn,
312 &mut self.fpn_ir,
313 &outputs,
314 &self.hiera_shapes.stage_hw,
315 &self.hiera_shapes.stage_dims,
316 )
317 }
318
319 pub fn predict_image(
335 &mut self,
336 image_u8: &[u8],
337 h_in: usize,
338 w_in: usize,
339 points: Option<(&[f32], &[f32])>,
340 boxes: Option<&[f32]>,
341 mask_input: Option<&[f32]>,
342 multimask_output: bool,
343 ) -> Result<Sam2ImagePrediction> {
344 let levels = self.encode(image_u8, h_in, w_in)?;
345 let prompt = self.run_prompt(points, boxes, mask_input)?;
349 let dec = self.run_decoder(&levels, &prompt, multimask_output)?;
350
351 Ok(Sam2ImagePrediction {
352 masks: dec.masks,
353 iou_pred: dec.iou_pred,
354 num_masks: dec.num_masks,
355 h_out: dec.h_out,
356 w_out: dec.w_out,
357 object_score_logits: dec.object_score_logits,
358 object_pointer: dec.object_pointer,
359 })
360 }
361
362 fn run_prompt(
363 &mut self,
364 points: Option<(&[f32], &[f32])>,
365 boxes: Option<&[f32]>,
366 mask_input: Option<&[f32]>,
367 ) -> Result<Sam2PromptEncoderOutput> {
368 prompt_encoder_forward(
369 &self.prompt_enc,
370 &mut self.mask_stack,
371 points,
372 boxes,
373 mask_input,
374 )
375 }
376
377 fn run_decoder(
378 &mut self,
379 levels: &[FpnLevel],
380 prompt: &Sam2PromptEncoderOutput,
381 multimask_output: bool,
382 ) -> Result<Sam2MaskDecoderOutput> {
383 let lvl_stride16 = &levels[2]; let lvl_stride8 = &levels[1]; let lvl_stride4 = &levels[0]; let high_res_features = if self.mask_dec.use_high_res_features {
388 Some((
389 lvl_stride4.features.as_slice(),
390 lvl_stride8.features.as_slice(),
391 ))
392 } else {
393 None
394 };
395
396 ensure!(
397 lvl_stride16.h == SAM2_PROMPT_GRID && lvl_stride16.w == SAM2_PROMPT_GRID,
398 "stride-16 FPN level must be {}×{} (got {}×{})",
399 SAM2_PROMPT_GRID,
400 SAM2_PROMPT_GRID,
401 lvl_stride16.h,
402 lvl_stride16.w
403 );
404
405 mask_decoder_forward(
406 &self.mask_dec,
407 &mut self.upscale,
408 Some(&mut self.hyper_matmul),
409 Some(&mut self.hyper_mlps_ir),
410 Some(&mut self.iou_head_ir),
411 self.obj_score_head_ir.as_mut(),
412 self.obj_ptr_proj_ir.as_mut(),
413 Some(&mut self.tw_ir),
414 &lvl_stride16.features,
415 &lvl_stride16.pos,
416 &prompt.sparse_embeddings,
417 prompt.num_sparse_tokens,
418 &prompt.dense_embeddings,
419 high_res_features,
420 multimask_output,
421 SAM2_PROMPT_GRID,
422 )
423 }
424
425 pub fn predict_video_frame(
434 &mut self,
435 state: &mut Sam2VideoState,
436 image_u8: &[u8],
437 h_in: usize,
438 w_in: usize,
439 points: Option<(&[f32], &[f32])>,
440 boxes: Option<&[f32]>,
441 mask_input: Option<&[f32]>,
442 multimask_output: bool,
443 ) -> Result<Sam2ImagePrediction> {
444 let levels = self.encode(image_u8, h_in, w_in)?;
445
446 let stride32 = &levels[3];
450 let mut conditioned_stride32: Vec<f32> = stride32.features.clone();
451 if !state.memory.is_empty() {
452 let curr = nchw_to_seq_c(
453 &stride32.features,
454 self.cfg.memory.d_model,
455 stride32.h,
456 stride32.w,
457 );
458 let curr_pos = nchw_to_seq_c(
459 &stride32.pos,
460 self.cfg.memory.d_model,
461 stride32.h,
462 stride32.w,
463 );
464
465 let (memory_flat, memory_pos_flat, n_mem) =
466 state.assembled_memory(self.cfg.memory.kv_in_dim, self.cfg.memory.mem_dim);
467 let n_img = stride32.h * stride32.w;
468 let num_ptr = state.num_obj_ptr_tokens(self.cfg.memory.mem_dim);
469 self.ensure_mem_attn_ir()?;
470 let ir = self.mem_attn_ir.as_mut().expect("mem_attn_ir");
471 let attn_out = if n_img == ir.n_img && n_mem <= ir.max_n_mem {
472 ir.run(
473 &curr,
474 &curr_pos,
475 &memory_flat,
476 &memory_pos_flat,
477 n_mem,
478 num_ptr,
479 )?
480 } else {
481 memory_attention_forward(
482 &self.mem_attn,
483 &curr,
484 &curr_pos,
485 &memory_flat,
486 &memory_pos_flat,
487 n_img,
488 n_mem,
489 self.cfg.memory.kv_in_dim,
490 num_ptr,
491 )?
492 };
493 conditioned_stride32 =
495 seq_c_to_nchw(&attn_out, self.cfg.memory.d_model, stride32.h, stride32.w);
496 }
497
498 let mut levels = levels;
504 levels[3].features = conditioned_stride32;
505
506 let prompt = self.run_prompt(points, boxes, mask_input)?;
507 let dec = self.run_decoder(&levels, &prompt, multimask_output)?;
508
509 let stride16 = &levels[2];
512 let mem = run_memory_encoder(&mut self.mem_enc, &stride16.features, &dec)?;
513 state.push_frame_memory(
514 mem,
515 dec.object_pointer.clone(),
516 self.cfg.memory.max_obj_ptrs_in_encoder,
517 );
518
519 Ok(Sam2ImagePrediction {
520 masks: dec.masks,
521 iou_pred: dec.iou_pred,
522 num_masks: dec.num_masks,
523 h_out: dec.h_out,
524 w_out: dec.w_out,
525 object_score_logits: dec.object_score_logits,
526 object_pointer: dec.object_pointer,
527 })
528 }
529}
530
/// Result of one SAM2 image or video-frame prediction, copied from the mask decoder.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Sam2ImagePrediction {
    /// Mask values, presumably `num_masks` consecutive `h_out × w_out` row-major maps
    /// (logits vs. probabilities is decided by the decoder — confirm there).
    pub masks: Vec<f32>,
    /// Predicted quality score per mask, as produced by the decoder's IoU head.
    pub iou_pred: Vec<f32>,
    /// Number of masks stored in `masks`.
    pub num_masks: usize,
    /// Height of each mask.
    pub h_out: usize,
    /// Width of each mask.
    pub w_out: usize,
    /// Object-score logits from the decoder.
    pub object_score_logits: Vec<f32>,
    /// Object pointer vector, when the decoder produces one.
    pub object_pointer: Option<Vec<f32>>,
}
542
543pub struct Sam2VideoState {
547 pub memory: Vec<Sam2MemoryEncoderOutput>,
549 pub obj_ptr_queue: Vec<Vec<f32>>,
550}
551
552impl Sam2VideoState {
553 pub fn new() -> Self {
554 Self {
555 memory: Vec::new(),
556 obj_ptr_queue: Vec::new(),
557 }
558 }
559
560 pub fn num_obj_ptr_tokens(&self, _mem_dim: usize) -> usize {
564 self.obj_ptr_queue.len()
570 }
571
572 pub fn assembled_memory(
578 &self,
579 kv_in_dim: usize,
580 _mem_dim: usize,
581 ) -> (Vec<f32>, Vec<f32>, usize) {
582 let mut features = Vec::new();
583 let mut positions = Vec::new();
584 let mut total_tokens = 0usize;
585
586 for m in &self.memory {
587 let tokens = m.h * m.w;
588 let mut feat_seq = vec![0f32; tokens * kv_in_dim];
590 let mut pos_seq = vec![0f32; tokens * kv_in_dim];
591 let pe_chans = m.pos.len() / (m.h * m.w);
592 for t in 0..tokens {
593 for c in 0..kv_in_dim {
594 feat_seq[t * kv_in_dim + c] = m.features[c * tokens + t];
595 }
596 for c in 0..kv_in_dim.min(pe_chans) {
599 pos_seq[t * kv_in_dim + c] = m.pos[c * tokens + t];
600 }
601 }
602 features.extend_from_slice(&feat_seq);
603 positions.extend_from_slice(&pos_seq);
604 total_tokens += tokens;
605 }
606
607 for ptr in &self.obj_ptr_queue {
610 ensure_or_zero(&mut features, &mut positions, ptr, kv_in_dim);
611 total_tokens += 1;
612 }
613
614 (features, positions, total_tokens)
615 }
616
617 fn push_frame_memory(
618 &mut self,
619 mem: Sam2MemoryEncoderOutput,
620 obj_ptr: Option<Vec<f32>>,
621 max_ptrs: usize,
622 ) {
623 self.memory.push(mem);
624 if let Some(p) = obj_ptr {
625 self.obj_ptr_queue.push(p);
626 while self.obj_ptr_queue.len() > max_ptrs {
627 self.obj_ptr_queue.remove(0);
628 }
629 }
630 }
631}
632
633impl Default for Sam2VideoState {
634 fn default() -> Self {
635 Self::new()
636 }
637}
638
/// Appends one object-pointer token of width `kv_in_dim`.
///
/// `ptr` is copied into `features`, truncated if longer than `kv_in_dim` and
/// zero-padded if shorter; `positions` receives `kv_in_dim` zeros.
fn ensure_or_zero(
    features: &mut Vec<f32>,
    positions: &mut Vec<f32>,
    ptr: &[f32],
    kv_in_dim: usize,
) {
    let kept = &ptr[..ptr.len().min(kv_in_dim)];
    features.extend_from_slice(kept);
    features.resize(features.len() + (kv_in_dim - kept.len()), 0.0);
    positions.resize(positions.len() + kv_in_dim, 0.0);
}
664
665fn run_memory_encoder(
666 mem_enc: &mut Sam2MemoryEncoderWeights,
667 pix_feat: &[f32],
668 dec: &Sam2MaskDecoderOutput,
669) -> Result<Sam2MemoryEncoderOutput> {
670 let m_chunk = dec.h_out * dec.w_out;
675 ensure!(
676 dec.masks.len() >= m_chunk,
677 "decoder produced empty mask buffer"
678 );
679 let mask0 = &dec.masks[..m_chunk];
680
681 let mut up_mask = vec![0f32; SAM2_IMG_SIZE * SAM2_IMG_SIZE];
685 bilinear_upsample_1ch(
686 mask0,
687 dec.h_out,
688 dec.w_out,
689 &mut up_mask,
690 SAM2_IMG_SIZE,
691 SAM2_IMG_SIZE,
692 );
693
694 memory_encoder_forward(
695 mem_enc,
696 pix_feat,
697 &up_mask,
698 SAM2_PROMPT_GRID,
699 SAM2_PROMPT_GRID,
700 false,
701 )
702}
703
/// Bilinearly resizes the row-major `sh × sw` single-channel map `src` into the
/// row-major `dh × dw` buffer `dst`, using half-pixel centres (`align_corners = false`).
///
/// Source coordinates before the first pixel centre are clamped to it, so border
/// outputs reproduce the edge pixel rather than blending towards the interior. If any
/// dimension is zero there is nothing to sample and `dst` is left unchanged.
///
/// # Panics
/// Panics if `src` holds fewer than `sh * sw` values or `dst` fewer than `dh * dw`.
fn bilinear_upsample_1ch(src: &[f32], sh: usize, sw: usize, dst: &mut [f32], dh: usize, dw: usize) {
    /// Maps output index `i` to its two neighbouring source indices and the weight of
    /// the second one.
    fn span(i: usize, scale: f32, len: usize) -> (usize, usize, f32) {
        // Clamp before splitting into index + fraction so negative coordinates get
        // weight 0 on the interior neighbour.
        let pos = ((i as f32 + 0.5) * scale - 0.5).max(0.0);
        let i0 = (pos.floor() as usize).min(len - 1);
        let i1 = (i0 + 1).min(len - 1);
        (i0, i1, (pos - i0 as f32).clamp(0.0, 1.0))
    }

    if sh == 0 || sw == 0 || dh == 0 || dw == 0 {
        return;
    }
    let sx = sw as f32 / dw as f32;
    let sy = sh as f32 / dh as f32;
    // Column spans are identical for every row; compute them once.
    let x_spans: Vec<(usize, usize, f32)> = (0..dw).map(|x| span(x, sx, sw)).collect();

    for y in 0..dh {
        let (y0, y1, dy) = span(y, sy, sh);
        let row = &mut dst[y * dw..(y + 1) * dw];
        for (out, &(x0, x1, dx)) in row.iter_mut().zip(&x_spans) {
            let p00 = src[y0 * sw + x0];
            let p01 = src[y0 * sw + x1];
            let p10 = src[y1 * sw + x0];
            let p11 = src[y1 * sw + x1];
            let top = p00 * (1.0 - dx) + p01 * dx;
            let bot = p10 * (1.0 - dx) + p11 * dx;
            *out = top * (1.0 - dy) + bot * dy;
        }
    }
}
727
/// Transposes a channel-major `c × h × w` buffer into token-major order: for each
/// spatial position (row-major), its `c` channel values are stored contiguously.
fn nchw_to_seq_c(src: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let plane = h * w;
    (0..plane)
        .flat_map(move |token| (0..c).map(move |ch| src[ch * plane + token]))
        .collect()
}
739
/// Inverse of the token-major layout: rebuilds a channel-major `c × h × w` buffer from
/// `h * w` tokens of `c` contiguous channel values each.
fn seq_c_to_nchw(src: &[f32], c: usize, h: usize, w: usize) -> Vec<f32> {
    let plane = h * w;
    (0..c)
        .flat_map(move |ch| (0..plane).map(move |token| src[token * c + ch]))
        .collect()
}
751
752#[allow(dead_code)]
753fn _silence_decoder_cfg(_d: &Sam2DecoderConfig) {}