1use super::config::Sam3Config;
24use super::detector::{Sam3DetectorWeights, detector_forward_native};
25use super::detector_decoder::{
26 Sam3DecoderOutput, Sam3DecoderWeights, extract_decoder_weights, forward_decoder,
27};
28use super::detector_encoder::{Sam3EncoderWeights, extract_encoder_weights, forward_encoder};
29use super::detector_encoder_ir::{forward_encoder_ir, forward_encoder_ir_on_with_profile};
30use super::geometry::{Sam3GeometryWeights, encode_geometry_native};
31use super::neck::{
32 Sam3NeckWeights, apply_neck_native, compile_neck_branches, extract_neck_weights,
33};
34use super::preprocess::{assemble_patch_tokens, preprocess_image};
35use super::segmentation_head::{
36 Sam3DotProductScoringWeights, Sam3SegmentationHeadWeights, Sam3SegmentationOutput,
37 compile_segmentation_ir, extract_dot_product_scoring_weights,
38 extract_segmentation_head_weights, forward_dot_prod_scoring, forward_segmentation,
39 segmentation_forward_native,
40};
41use super::text_encoder::{
42 Sam3TextEncoded, Sam3TextEncoderWeights, encode_text_native, encode_tokens,
43 extract_text_encoder_weights,
44};
45use super::tracker::{Sam3TrackerWeights, extract_tracker_weights, tracker_forward_native};
46use super::vision_encoder::{
47 Sam3VisionEncoderWeights, encode_image_native, extract_vision_encoder_weights,
48};
49use anyhow::{Context, Result, ensure};
50use rlx_flow::CompileProfile;
51use rlx_runtime::Device;
52use rlx_sam::profile::sam3_profile_near_weights;
53use std::path::Path;
54
/// Result of the image-encoding entry points on [`Sam3`].
///
/// Produced by `Sam3::encode_image` (full vision encoder) and by
/// `Sam3::patch_embed_image` (patch-embedding step only).
#[derive(Clone, Debug)]
pub struct Sam3EncodedImage {
    /// Flat token buffer returned by the encoder / patch-embedding step.
    /// Presumably `grid * grid * embed_dim` values — TODO confirm the layout.
    pub patch_tokens: Vec<f32>,
    /// Grid size reported together with the tokens.
    pub grid: usize,
    /// Embedding dimension reported together with the tokens.
    pub embed_dim: usize,
    /// Second value returned by `preprocess_image` (the resized image size).
    pub resized_hw: (usize, usize),
}
63
/// Per-image prediction: flat `f32` buffers plus the shape of each buffer.
///
/// The shapes documented below are the ones `Sam3::predict_image_text`
/// writes; other producers may use different ones.
#[derive(Clone, Debug)]
pub struct Sam3ImagePrediction {
    /// Flat mask buffer; its dimensions are given by `mask_shape`.
    pub masks: Vec<f32>,
    /// Shape of `masks` (`[batch, queries, mask_h, mask_w]` in `predict_image_text`).
    pub mask_shape: Vec<usize>,
    /// Flat box buffer; `predict_image_text` stores `(x0, y0, x1, y1)` per query.
    pub boxes: Vec<f32>,
    /// Shape of `boxes` (`[batch, queries, 4]` in `predict_image_text`).
    pub boxes_shape: Vec<usize>,
    /// Flat score buffer, one value per query.
    pub scores: Vec<f32>,
    /// Shape of `scores` (`[batch, queries]` in `predict_image_text`).
    pub scores_shape: Vec<usize>,
    /// Number of instances described by the buffers above.
    pub num_instances: usize,
    /// Output height; the predictors in this file pass the resized image height.
    pub h_out: usize,
    /// Output width; the predictors in this file pass the resized image width.
    pub w_out: usize,
}
78
79#[derive(Debug, Clone, Default)]
80pub struct Sam3VideoState {
81 pub frame_index: usize,
82 pub memory_tokens: Vec<Vec<f32>>,
83 pub last_prediction: Option<Sam3ImagePrediction>,
84}
85
86#[derive(Debug, Clone)]
87pub struct Sam3VideoFramePrediction {
88 pub frame_index: usize,
89 pub image: Sam3ImagePrediction,
90 pub memory_len: usize,
91}
92
93pub struct Sam3 {
94 cfg: Sam3Config,
95 vision: Option<Sam3VisionEncoderWeights>,
96 neck: Sam3NeckWeights,
97 text: Sam3TextEncoderWeights,
98 geometry: Sam3GeometryWeights,
99 detector: Sam3DetectorWeights,
100 encoder: Sam3EncoderWeights,
101 decoder: Sam3DecoderWeights,
102 seg_head: Sam3SegmentationHeadWeights,
103 scoring: Sam3DotProductScoringWeights,
104 seg: Sam3SegmentationHeadWeights,
105 tracker: Sam3TrackerWeights,
106 device: Device,
107 compile_profile: CompileProfile,
108 gguf_packed: Option<rlx_flow::GgufPackedParams>,
109}
110
111impl Sam3 {
112 pub fn from_checkpoint(weights_path: &str, cfg: Sam3Config) -> Result<Self> {
119 Self::from_checkpoint_on(weights_path, cfg, Device::Cpu)
120 }
121
122 pub fn from_checkpoint_on(weights_path: &str, cfg: Sam3Config, device: Device) -> Result<Self> {
123 Self::from_safetensors_on(weights_path, cfg, device)
124 }
125
126 pub fn from_safetensors(weights_path: &str, cfg: Sam3Config) -> Result<Self> {
127 Self::from_safetensors_on(weights_path, cfg, Device::Cpu)
128 }
129
130 pub fn from_safetensors_on(
131 weights_path: &str,
132 cfg: Sam3Config,
133 device: Device,
134 ) -> Result<Self> {
135 rlx_core::validate_sam_device("sam3", device)?;
136
137 let path = Path::new(weights_path);
138 let is_gguf = path.extension().is_some_and(|e| e == "gguf");
139 if is_gguf {
140 rlx_core::gguf_validate_arch(path, rlx_core::SAM3_GGUF_ARCHES)?;
141 }
142 let (mut wm, gguf_packed) = if is_gguf && crate::packed_gguf::gguf_has_packed_linears(path)?
143 {
144 eprintln!("[sam3] loading GGUF with packed ViT matmul {path:?}");
145 let (wm, packed) = crate::packed_gguf::load_sam3_from_gguf(path)?;
146 (wm, Some(packed))
147 } else {
148 (
149 rlx_core::load_weight_map(path, rlx_core::SAM3_GGUF_ARCHES)?,
150 None,
151 )
152 };
153 let compile_profile = sam3_profile_near_weights(path);
154 let vision = extract_vision_encoder_weights(&mut wm, &cfg.vit, gguf_packed.as_ref())?;
155 let mut neck = extract_neck_weights(&mut wm)?;
156 compile_neck_branches(
157 &mut neck,
158 cfg.vit.embed_dim,
159 cfg.vit.patch_grid(),
160 device,
161 &compile_profile,
162 )?;
163 let text = extract_text_encoder_weights(&mut wm, &cfg.text, gguf_packed.as_ref())?;
164 let encoder = extract_encoder_weights(&mut wm, gguf_packed.as_ref())?;
165 let decoder = extract_decoder_weights(&mut wm, gguf_packed.as_ref())?;
166 let mut seg_head = extract_segmentation_head_weights(&mut wm, gguf_packed.as_ref())?;
167 compile_segmentation_ir(
168 &mut seg_head,
169 gguf_packed.as_ref(),
170 cfg.vit.patch_grid(),
171 device,
172 &compile_profile,
173 )?;
174 let scoring = extract_dot_product_scoring_weights(&mut wm, gguf_packed.as_ref())?;
175 let tracker = extract_tracker_weights(&mut wm)?;
176 Ok(Self {
177 cfg,
178 vision: Some(vision),
179 neck,
180 text,
181 geometry: Sam3GeometryWeights::default(),
182 detector: Sam3DetectorWeights::default(),
183 encoder,
184 seg: Sam3SegmentationHeadWeights::default(),
185 tracker,
186 decoder,
187 seg_head,
188 scoring,
189 device,
190 compile_profile,
191 gguf_packed,
192 })
193 }
194
195 pub fn compile_profile(&self) -> &CompileProfile {
197 &self.compile_profile
198 }
199
200 pub fn config(&self) -> &Sam3Config {
201 &self.cfg
202 }
203
204 pub fn tracker_weights(&self) -> &Sam3TrackerWeights {
207 &self.tracker
208 }
209
210 pub fn encoder_weights(&self) -> &Sam3EncoderWeights {
211 &self.encoder
212 }
213
214 pub fn decoder_weights(&self) -> &Sam3DecoderWeights {
215 &self.decoder
216 }
217
218 pub fn device(&self) -> Device {
219 self.device
220 }
221
222 pub fn encode_image(
223 &self,
224 image_u8: &[u8],
225 h_in: usize,
226 w_in: usize,
227 ) -> Result<Sam3EncodedImage> {
228 let vision = self
229 .vision
230 .as_ref()
231 .context("SAM3 encode_image requires native vision weights")?;
232 let (image_nchw, resized_hw) = preprocess_image(image_u8, h_in, w_in);
233 let encoded = encode_image_native(
234 vision,
235 self.gguf_packed.as_ref(),
236 &self.cfg.vit,
237 &image_nchw,
238 )?;
239 Ok(Sam3EncodedImage {
240 patch_tokens: encoded.tokens,
241 grid: encoded.grid,
242 embed_dim: encoded.dim,
243 resized_hw,
244 })
245 }
246
247 pub fn predict_image_text(
256 &mut self,
257 image_u8: &[u8],
258 h_in: usize,
259 w_in: usize,
260 tokens: &[u32],
261 ) -> Result<Sam3ImagePrediction> {
262 let cfg = &self.cfg;
263 let nq = 200;
264 let seq_len = tokens.len();
265
266 let vision = self
268 .vision
269 .as_ref()
270 .context("predict_image_text requires native vision weights")?;
271 let (image_nchw, resized_hw) = preprocess_image(image_u8, h_in, w_in);
272 let vision_out = super::vision_encoder::encode_image_native(
273 vision,
274 self.gguf_packed.as_ref(),
275 &cfg.vit,
276 &image_nchw,
277 )?;
278 let levels = apply_neck_native(&mut self.neck, &vision_out)?;
279 let kept = &levels[..3];
281 let backbone_fpn: Vec<Vec<f32>> = kept.iter().map(|l| l.features.clone()).collect();
282 let backbone_shapes: Vec<(usize, usize)> = kept.iter().map(|l| (l.h, l.w)).collect();
283 let src_level = &kept[2];
285 let h = src_level.h;
286 let w = src_level.w;
287 let batch = 1;
288
289 let text_out = encode_tokens(
291 &self.text,
292 tokens,
293 batch,
294 seq_len,
295 self.gguf_packed.as_ref(),
296 )?;
297
298 let memory_bf = forward_encoder(
300 &self.encoder,
301 &src_level.features,
302 &src_level.pos,
303 &text_out.text_memory_resized,
304 &text_out.attention_mask,
305 batch,
306 h,
307 w,
308 seq_len,
309 self.gguf_packed.as_ref(),
310 )?;
311 let mut memory_pos = vec![0f32; batch * h * w * 256];
313 for b in 0..batch {
314 for y in 0..h {
315 for xc in 0..w {
316 for c in 0..256 {
317 memory_pos[(b * h * w + y * w + xc) * 256 + c] =
318 src_level.pos[((b * 256 + c) * h + y) * w + xc];
319 }
320 }
321 }
322 }
323
324 let dec = forward_decoder(
326 &self.decoder,
327 &memory_bf,
328 &memory_pos,
329 &text_out.text_memory_resized,
330 &text_out.attention_mask,
331 batch,
332 h,
333 w,
334 seq_len,
335 self.gguf_packed.as_ref(),
336 )?;
337
338 let num_layers = dec.num_layers;
340 let mut queries_last_bf = vec![0f32; batch * nq * 256];
341 let li = num_layers - 1;
342 for q in 0..nq {
343 for b in 0..batch {
344 let src = ((li * nq + q) * batch + b) * 256;
345 let dst = (b * nq + q) * 256;
346 queries_last_bf[dst..dst + 256].copy_from_slice(&dec.intermediate[src..src + 256]);
347 }
348 }
349
350 let mut ref_last_bf = vec![0f32; batch * nq * 4];
352 for q in 0..nq {
353 for b in 0..batch {
354 let src = ((li * nq + q) * batch + b) * 4;
355 let dst = (b * nq + q) * 4;
356 ref_last_bf[dst..dst + 4]
357 .copy_from_slice(&dec.intermediate_ref_boxes[src..src + 4]);
358 }
359 }
360
361 let delta = super::detector_decoder::bbox_embed_forward(
363 &self.decoder,
364 &queries_last_bf,
365 batch * nq,
366 self.gguf_packed.as_ref(),
367 )?;
368 let mut final_boxes_cxcywh = vec![0f32; batch * nq * 4];
369 for q in 0..nq {
370 for b in 0..batch {
371 let rb = &ref_last_bf[(b * nq + q) * 4..(b * nq + q + 1) * 4];
372 let d = &delta[(b * nq + q) * 4..(b * nq + q + 1) * 4];
373 let out_off = (b * nq + q) * 4;
374 for k in 0..4 {
375 let inv = if rb[k] <= 0.0 {
376 (1e-3f32 / (1.0 - 1e-3)).ln()
377 } else if rb[k] >= 1.0 {
378 ((1.0 - 1e-3) / 1e-3f32).ln()
379 } else {
380 (rb[k].max(1e-3) / (1.0 - rb[k]).max(1e-3)).ln()
381 };
382 let s = inv + d[k];
383 final_boxes_cxcywh[out_off + k] = 1.0 / (1.0 + (-s).exp());
384 }
385 }
386 }
387 let mut boxes_xyxy = vec![0f32; batch * nq * 4];
389 for i in 0..(batch * nq) {
390 let cx = final_boxes_cxcywh[i * 4];
391 let cy = final_boxes_cxcywh[i * 4 + 1];
392 let bw = final_boxes_cxcywh[i * 4 + 2];
393 let bh = final_boxes_cxcywh[i * 4 + 3];
394 boxes_xyxy[i * 4] = cx - 0.5 * bw;
395 boxes_xyxy[i * 4 + 1] = cy - 0.5 * bh;
396 boxes_xyxy[i * 4 + 2] = cx + 0.5 * bw;
397 boxes_xyxy[i * 4 + 3] = cy + 0.5 * bh;
398 }
399
400 let mut hs_bf = vec![0f32; num_layers * batch * nq * 256];
402 for l in 0..num_layers {
403 for q in 0..nq {
404 for b in 0..batch {
405 let src = ((l * nq + q) * batch + b) * 256;
406 let dst = ((l * batch + b) * nq + q) * 256;
407 hs_bf[dst..dst + 256].copy_from_slice(&dec.intermediate[src..src + 256]);
408 }
409 }
410 }
411 let all_scores = forward_dot_prod_scoring(
412 &self.scoring,
413 &hs_bf,
414 &text_out.text_memory_resized,
415 &text_out.attention_mask,
416 num_layers,
417 batch,
418 nq,
419 seq_len,
420 self.gguf_packed.as_ref(),
421 )?;
422 let last_scores =
423 all_scores[(num_layers - 1) * batch * nq..num_layers * batch * nq].to_vec();
424
425 let seg = forward_segmentation(
427 &mut self.seg_head,
428 &memory_bf,
429 &backbone_fpn,
430 &backbone_shapes,
431 &queries_last_bf,
432 &text_out.text_memory_resized,
433 &text_out.attention_mask,
434 batch,
435 h,
436 w,
437 nq,
438 seq_len,
439 self.gguf_packed.as_ref(),
440 )?;
441
442 Ok(Sam3ImagePrediction {
443 masks: seg.mask_pred,
444 mask_shape: vec![batch, nq, seg.h_out, seg.w_out],
445 boxes: boxes_xyxy,
446 boxes_shape: vec![batch, nq, 4],
447 scores: last_scores,
448 scores_shape: vec![batch, nq],
449 num_instances: nq,
450 h_out: resized_hw.0,
451 w_out: resized_hw.1,
452 })
453 }
454
455 #[allow(clippy::too_many_arguments)]
459 pub fn run_segmentation(
460 &mut self,
461 enc_memory_bf: &[f32],
462 backbone_fpn: &[Vec<f32>],
463 backbone_shapes: &[(usize, usize)],
464 obj_queries_last_bf: &[f32],
465 prompt_seq_first: &[f32],
466 prompt_kpm: &[u8],
467 batch: usize,
468 enc_h: usize,
469 enc_w: usize,
470 num_queries: usize,
471 seq_len: usize,
472 ) -> Result<Sam3SegmentationOutput> {
473 forward_segmentation(
474 &mut self.seg_head,
475 enc_memory_bf,
476 backbone_fpn,
477 backbone_shapes,
478 obj_queries_last_bf,
479 prompt_seq_first,
480 prompt_kpm,
481 batch,
482 enc_h,
483 enc_w,
484 num_queries,
485 seq_len,
486 self.gguf_packed.as_ref(),
487 )
488 }
489
490 #[allow(clippy::too_many_arguments)]
493 pub fn run_dot_prod_scoring(
494 &self,
495 hs_bf: &[f32],
496 prompt_seq_first: &[f32],
497 prompt_kpm: &[u8],
498 num_layers: usize,
499 batch: usize,
500 num_queries: usize,
501 seq_len: usize,
502 ) -> Result<Vec<f32>> {
503 forward_dot_prod_scoring(
504 &self.scoring,
505 hs_bf,
506 prompt_seq_first,
507 prompt_kpm,
508 num_layers,
509 batch,
510 num_queries,
511 seq_len,
512 self.gguf_packed.as_ref(),
513 )
514 }
515
516 #[allow(clippy::too_many_arguments)]
523 pub fn run_decoder(
524 &self,
525 memory: &[f32],
526 memory_pos: &[f32],
527 memory_text: &[f32],
528 text_attention_mask: &[u8],
529 batch: usize,
530 h: usize,
531 w: usize,
532 seq_len: usize,
533 ) -> Result<Sam3DecoderOutput> {
534 if rlx_ir::env::flag("RLX_SAM3_DECODER_HOST") {
535 return forward_decoder(
536 &self.decoder,
537 memory,
538 memory_pos,
539 memory_text,
540 text_attention_mask,
541 batch,
542 h,
543 w,
544 seq_len,
545 self.gguf_packed.as_ref(),
546 );
547 }
548 let dev = match rlx_ir::env::var("RLX_SAM3_DECODER_DEVICE").as_deref() {
549 Some("metal") => Device::Metal,
550 Some("mlx") => Device::Mlx,
551 Some("cuda") => Device::Cuda,
552 _ => self.device,
553 };
554 super::detector_decoder_ir::forward_decoder_ir_on_with_profile(
555 &self.decoder,
556 memory,
557 memory_pos,
558 memory_text,
559 text_attention_mask,
560 batch,
561 h,
562 w,
563 seq_len,
564 dev,
565 &self.compile_profile,
566 self.gguf_packed.as_ref(),
567 )
568 }
569
570 #[allow(clippy::too_many_arguments)]
574 pub fn run_encoder(
575 &self,
576 src_bchw: &[f32],
577 src_pos_bchw: &[f32],
578 prompt_seq_first: &[f32],
579 prompt_kpm: &[u8],
580 batch: usize,
581 src_h: usize,
582 src_w: usize,
583 prompt_len: usize,
584 ) -> Result<Vec<f32>> {
585 if rlx_ir::env::flag("RLX_SAM3_ENCODER_HOST") {
589 return forward_encoder(
590 &self.encoder,
591 src_bchw,
592 src_pos_bchw,
593 prompt_seq_first,
594 prompt_kpm,
595 batch,
596 src_h,
597 src_w,
598 prompt_len,
599 self.gguf_packed.as_ref(),
600 );
601 }
602 let dev = match rlx_ir::env::var("RLX_SAM3_ENCODER_DEVICE").as_deref() {
603 Some("metal") => Device::Metal,
604 Some("mlx") => Device::Mlx,
605 _ => Device::Cpu,
606 };
607 let _ = forward_encoder_ir; forward_encoder_ir_on_with_profile(
609 &self.encoder,
610 src_bchw,
611 src_pos_bchw,
612 prompt_seq_first,
613 prompt_kpm,
614 batch,
615 src_h,
616 src_w,
617 prompt_len,
618 dev,
619 &self.compile_profile,
620 self.gguf_packed.as_ref(),
621 )
622 }
623
624 pub fn encode_text_tokens(
627 &self,
628 tokens: &[u32],
629 batch: usize,
630 seq_len: usize,
631 ) -> Result<Sam3TextEncoded> {
632 encode_tokens(
633 &self.text,
634 tokens,
635 batch,
636 seq_len,
637 self.gguf_packed.as_ref(),
638 )
639 }
640
641 pub fn predict_neck(
642 &mut self,
643 image_u8: &[u8],
644 h_in: usize,
645 w_in: usize,
646 ) -> Result<Vec<super::neck::Sam3FeatureLevel>> {
647 let vision = self
648 .vision
649 .as_ref()
650 .context("SAM3 predict_neck requires native vision weights")?;
651 let (image_nchw, _) = preprocess_image(image_u8, h_in, w_in);
652 let vision_out = super::vision_encoder::encode_image_native(
653 vision,
654 self.gguf_packed.as_ref(),
655 &self.cfg.vit,
656 &image_nchw,
657 )?;
658 apply_neck_native(&mut self.neck, &vision_out)
659 }
660
661 pub fn patch_embed_image(
662 &self,
663 image_u8: &[u8],
664 h_in: usize,
665 w_in: usize,
666 ) -> Result<Sam3EncodedImage> {
667 let vision = self
668 .vision
669 .as_ref()
670 .context("SAM3 patch_embed_image requires native vision weights")?;
671 let (image_nchw, resized_hw) = preprocess_image(image_u8, h_in, w_in);
672 let patch_tokens = assemble_patch_tokens(&vision.pre, &image_nchw)?;
673 Ok(Sam3EncodedImage {
674 patch_tokens,
675 grid: vision.pre.grid,
676 embed_dim: vision.pre.embed_dim,
677 resized_hw,
678 })
679 }
680
681 pub fn predict_image(
682 &mut self,
683 image_u8: &[u8],
684 h_in: usize,
685 w_in: usize,
686 text_prompt: Option<&str>,
687 boxes: Option<&[f32]>,
688 points: Option<(&[f32], &[f32])>,
689 ) -> Result<Sam3ImagePrediction> {
690 self.predict_image_native(image_u8, h_in, w_in, text_prompt, boxes, points)
691 }
692
693 pub fn predict_video_frame(
694 &mut self,
695 state: &mut Sam3VideoState,
696 image_u8: &[u8],
697 h_in: usize,
698 w_in: usize,
699 text_prompt: Option<&str>,
700 ) -> Result<Sam3VideoFramePrediction> {
701 let pred = self.predict_image_native(image_u8, h_in, w_in, text_prompt, None, None)?;
702 Ok(tracker_forward_native(&self.tracker, state, pred))
703 }
704
705 fn predict_image_native(
706 &mut self,
707 image_u8: &[u8],
708 h_in: usize,
709 w_in: usize,
710 text_prompt: Option<&str>,
711 boxes: Option<&[f32]>,
712 points: Option<(&[f32], &[f32])>,
713 ) -> Result<Sam3ImagePrediction> {
714 ensure!(
715 image_u8.len() == h_in * w_in * 3,
716 "SAM3 image must be RGB u8 with len {} (got {})",
717 h_in * w_in * 3,
718 image_u8.len()
719 );
720 let vision = self
721 .vision
722 .as_ref()
723 .context("SAM3 predict_image requires native vision weights")?;
724 let (image_nchw, resized_hw) = preprocess_image(image_u8, h_in, w_in);
725 let vision_out = encode_image_native(
726 vision,
727 self.gguf_packed.as_ref(),
728 &self.cfg.vit,
729 &image_nchw,
730 )?;
731 let levels = apply_neck_native(&mut self.neck, &vision_out)?;
732 let text = encode_text_native(
733 &self.text,
734 &self.cfg.text,
735 text_prompt,
736 self.gguf_packed.as_ref(),
737 )?;
738 let geometry = encode_geometry_native(&self.geometry, boxes, points);
739 let det = detector_forward_native(
740 &self.detector,
741 &self.cfg.detector,
742 &levels,
743 &text,
744 &geometry,
745 )?;
746 Ok(segmentation_forward_native(
747 &self.seg,
748 &det,
749 resized_hw.0,
750 resized_hw.1,
751 ))
752 }
753}